[assistant][ops] New operator implementation, including SST2Dataset
This commit is contained in:
parent 6d9fdaacc1
commit f287b4865c

@@ -0,0 +1,69 @@
mindspore.dataset.SST2Dataset
=============================

.. py:class:: mindspore.dataset.SST2Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None)

    A source dataset that reads and parses the SST2 dataset.

    The train.tsv and dev.tsv files in the dataset have two columns `[sentence, label]` .
    The test.tsv file has one column `[sentence]` .
    The `sentence` and `label` columns are both of string type.

    Parameters:
        - **dataset_dir** (str) - Path to the root directory that contains the dataset files.
        - **usage** (str, optional) - Subset of the dataset to read, can be 'train', 'test' or 'dev'.
          'train' reads the 67,349 training samples, 'test' reads the 1,821 test samples, and 'dev'
          reads all 872 evaluation samples. Default: None, reads the train samples.
        - **num_samples** (int, optional) - Number of samples to read from the dataset. Default: None, reads all samples.
        - **num_parallel_workers** (int, optional) - Number of worker threads used to read the data.
          Default: None, uses the number configured in mindspore.dataset.config.
        - **shuffle** (Union[bool, Shuffle], optional) - Shuffling mode applied to the data each epoch;
          both bool and the Shuffle enum are accepted. Default: `Shuffle.GLOBAL` .
          If `shuffle` is False, no shuffling is performed; if `shuffle` is True, it is equivalent to
          setting `shuffle` to mindspore.dataset.Shuffle.GLOBAL.
          The shuffling mode can also be set by passing in an enum value:

          - **Shuffle.GLOBAL**: shuffle the samples.

        - **num_shards** (int, optional) - Number of shards that the dataset will be divided into for
          distributed training. Default: None. When this argument is specified, `num_samples` reflects
          the maximum number of samples per shard.
        - **shard_id** (int, optional) - The shard ID to use for distributed training. Default: None.
          This argument can only be specified when `num_shards` is also specified.
        - **cache** (DatasetCache, optional) - Single-node data caching service, used to speed up dataset
          processing; see `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_
          for details. Default: None, no cache is used.

    Raises:
        - **RuntimeError** - The directory pointed to by `dataset_dir` does not exist or is missing dataset files.
        - **ValueError** - `num_parallel_workers` exceeds the maximum number of threads on the system.
        - **RuntimeError** - `num_shards` is specified but `shard_id` is not.
        - **RuntimeError** - `shard_id` is specified but `num_shards` is not.
        - **ValueError** - `shard_id` is invalid (less than 0 or greater than or equal to `num_shards`).

    **About the SST2 dataset:**

    The Stanford Sentiment Treebank is a corpus with fully labeled parse trees that allows a complete
    analysis of the compositional effects of sentiment in language. The corpus is based on the dataset
    introduced by Pang and Lee (2005) and consists of 11,855 sentences extracted from movie reviews.
    It was parsed with the Stanford parser and includes a total of 215,154 unique phrases from those
    parse trees, each annotated by 3 human judges.

    Below is the structure of the original SST2 dataset; you can unzip the dataset files into this
    directory layout and read them with MindSpore's API.

    .. code-block::

        .
        └── sst2_dataset_dir
            ├── train.tsv
            ├── test.tsv
            ├── dev.tsv
            └── original

    **Citation:**

    .. code-block::

        @inproceedings{socher-etal-2013-recursive,
            title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
            author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
                      Christopher D. and Ng, Andrew and Potts, Christopher},
            booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
            month = oct,
            year = {2013},
            address = {Seattle, Washington, USA},
            publisher = {Association for Computational Linguistics},
            url = {https://www.aclweb.org/anthology/D13-1170},
            pages = {1631--1642},
        }

.. include:: mindspore.dataset.api_list_vision.rst
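For quick orientation, here is a minimal usage sketch of the class documented above. It is illustrative only: the directory path is a placeholder and the `ds` alias mirrors the examples used elsewhere in this commit.

.. code-block:: python

    import mindspore.dataset as ds

    # Placeholder path; point this at the extracted sst2_dataset_dir.
    sst2_dir = "/path/to/sst2_dataset_dir"

    # Read the dev split without shuffling; dev.tsv yields [sentence, label] rows.
    dataset = ds.SST2Dataset(dataset_dir=sst2_dir, usage="dev", shuffle=False)

    for row in dataset.create_dict_iterator(output_numpy=True):
        print(row["sentence"], row["label"])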
@@ -149,6 +149,7 @@ mindspore.dataset
    mindspore.dataset.IWSLT2017Dataset
    mindspore.dataset.PennTreebankDataset
    mindspore.dataset.SogouNewsDataset
    mindspore.dataset.SST2Dataset
    mindspore.dataset.TextFileDataset
    mindspore.dataset.UDPOSDataset
    mindspore.dataset.WikiTextDataset
@@ -62,6 +62,7 @@ Text
    mindspore.dataset.IWSLT2017Dataset
    mindspore.dataset.PennTreebankDataset
    mindspore.dataset.SogouNewsDataset
    mindspore.dataset.SST2Dataset
    mindspore.dataset.TextFileDataset
    mindspore.dataset.UDPOSDataset
    mindspore.dataset.WikiTextDataset
@@ -121,6 +121,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
@@ -1856,6 +1857,14 @@ SQuADDataset::SQuADDataset(const std::vector<char> &dataset_dir, const std::vect
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

SST2Dataset::SST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                         ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                         const std::shared_ptr<DatasetCache> &cache) {
  auto ds = std::make_shared<SST2Node>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
                                       shard_id, cache);
  // Store the IR node as the generic DatasetNode base type, matching the other dataset constructors above.
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
                               const std::vector<char> &usage, const std::vector<char> &extensions,
                               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
@@ -59,6 +59,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
@@ -697,6 +698,18 @@ PYBIND_REGISTER(SQuADNode, 2, ([](const py::module *m) {
                  }));
                }));

PYBIND_REGISTER(SST2Node, 2, ([](const py::module *m) {
                  (void)py::class_<SST2Node, DatasetNode, std::shared_ptr<SST2Node>>(*m, "SST2Node",
                                                                                     "to create a SST2Node")
                    .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
                                     int32_t shuffle, int32_t num_shards, int32_t shard_id) {
                      auto sst2 = std::make_shared<SST2Node>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
                                                             num_shards, shard_id, nullptr);
                      THROW_IF_ERROR(sst2->ValidateParams());
                      return sst2;
                    }));
                }));

PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
                  (void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node",
                                                                                       "to create a STL10Node")
@@ -49,6 +49,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
    sogou_news_op.cc
    speech_commands_op.cc
    squad_op.cc
    sst2_op.cc
    stl10_op.cc
    sun397_op.cc
    tedlium_op.cc
@@ -225,7 +225,7 @@ class CsvOp : public NonMappableLeafOp {
  /// @param s - the input string
  /// @param delim - the delimiter
  /// @return - a string vector
  std::vector<std::string> split(const std::string &s, char delim);
  virtual std::vector<std::string> split(const std::string &s, char delim);

  // Private function for analysing the column name in every CSV file
  // @return - Status
@@ -0,0 +1,135 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/sst2_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <stdexcept>

#include "include/common/debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/random.h"
#include "utils/file_utils.h"

namespace mindspore {
namespace dataset {
SST2Op::SST2Op(const std::vector<std::string> &dataset_files_list, const std::string &usage, char field_delim,
               const std::vector<std::shared_ptr<BaseRecord>> &column_default,
               const std::vector<std::string> &column_name, int32_t num_workers, int64_t num_samples,
               int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
               int32_t device_id)
    : CsvOp(dataset_files_list, field_delim, column_default, column_name, num_workers, num_samples,
            worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id),
      usage_(usage) {}

void SST2Op::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nSST2 files list:\n";
    for (int i = 0; i < csv_files_list_.size(); ++i) {
      out << " " << csv_files_list_[i];
    }
    out << "\n\n";
  }
}

std::vector<std::string> SST2Op::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;
  // test.tsv carries a leading index column; drop the first field of each row.
  bool skip = usage_ == "test";
  while (getline(ss, item, delim)) {
    if (skip) {
      skip = false;
    } else {
      res.push_back(item);
    }
  }
  return res;
}

Status SST2Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  CsvParser csv_parser(worker_id, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
  RETURN_IF_NOT_OK(csv_parser.InitCsvParser());
  csv_parser.SetStartOffset(start_offset);
  csv_parser.SetEndOffset(end_offset);

  auto realpath = FileUtils::GetRealPath(file.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
  }

  std::ifstream ifs;
  ifs.open(realpath.value(), std::ifstream::in);
  if (!ifs.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + file + ", the file is damaged or permission denied.");
  }
  // Skip the header line when no explicit column names were provided.
  if (column_name_list_.empty()) {
    std::string tmp;
    getline(ifs, tmp);
  }
  bool skip = usage_ == "test";
  csv_parser.Reset();
  try {
    while (ifs.good()) {
      // when ifstream reaches the end of file, the function get() returns std::char_traits<char>::eof(),
      // which is a 32-bit -1; it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use
      // int to receive its return value.
      int chr = ifs.get();
      // For the test set, skip the leading index field up to the first delimiter of each line.
      if (skip) {
        if (chr == field_delim_) {
          skip = false;
        }
        continue;
      }
      if (usage_ == "test" && chr == '\n') {
        skip = true;
      }
      int err = csv_parser.ProcessMessage(chr);
      if (err != 0) {
        // if the error code is -2, the returned error is interrupted
        if (err == -2) {
          return Status(kMDInterrupted);
        }
        RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse csv file: " + file + " at line " +
                                 std::to_string(csv_parser.GetTotalRows() + 1) +
                                 ". Error message: " + csv_parser.GetErrorMessage());
      }
    }
  } catch (std::invalid_argument &ia) {
    std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
    RETURN_STATUS_UNEXPECTED("Invalid csv, csv file: " + file + " parse failed at line " + err_row +
                             ", type does not match.");
  } catch (std::out_of_range &oor) {
    std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
    RETURN_STATUS_UNEXPECTED("Invalid csv, " + file + " parse failed at line " + err_row + " : value out of range.");
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
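The usage-dependent field skipping implemented in `SST2Op::split` and `SST2Op::LoadFile` above can be modeled in a few lines of Python. This sketch is illustrative only (not part of the patch; the sample rows are invented): when reading test.tsv, the leading index field of each row is dropped.

.. code-block:: python

    def split(s: str, delim: str, usage: str) -> list:
        """Model of SST2Op::split: drop the first field when reading test.tsv."""
        fields = s.split(delim)
        return fields[1:] if usage == "test" else fields

    # Invented rows, for illustration only.
    print(split("0\tit 's a charming journey .", "\t", usage="test"))  # ["it 's a charming journey ."]
    print(split("hide new secretions\t0", "\t", usage="train"))        # ['hide new secretions', '0']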
@@ -0,0 +1,87 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_

#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/csv_op.h"

namespace mindspore {
namespace dataset {
class SST2Op : public CsvOp {
 public:
  /// \brief Constructor.
  /// \param[in] dataset_files_list List of file paths for the dataset files.
  /// \param[in] usage Usage of this dataset, can be 'train', 'test' or 'dev'.
  /// \param[in] field_delim A char that indicates the delimiter to separate fields.
  /// \param[in] column_default List of default values for the CSV fields (default={}). Each item in the list is
  ///     of a valid type (float, int, or string).
  /// \param[in] column_name List of column names of the dataset file.
  /// \param[in] num_workers Number of workers reading files in parallel.
  /// \param[in] num_samples The number of samples to be included in the dataset.
  /// \param[in] worker_connector_size Size of each internal queue.
  /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
  /// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
  /// \param[in] num_devices Number of devices that the dataset should be divided into.
  /// \param[in] device_id The device ID within num_devices.
  SST2Op(const std::vector<std::string> &dataset_files_list, const std::string &usage, char field_delim,
         const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
         int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);

  /// \brief Destructor.
  ~SST2Op() override = default;

  /// \brief A print method typically used for debugging.
  /// \param[out] out The output stream to write output to.
  /// \param[in] show_all A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

  /// \brief DatasetName name getter.
  /// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name.
  /// \return DatasetName of the current Op.
  std::string DatasetName(bool upper = false) const { return upper ? "SST2" : "sst2"; }

  /// \brief Op name getter.
  /// \return Name of the current Op.
  std::string Name() const override { return "SST2Op"; }

  /// \brief Split a string by a character delimiter.
  /// \param[in] s The input string.
  /// \param[in] delim The delimiter.
  /// \return A string vector.
  std::vector<std::string> split(const std::string &s, char delim) override;

  /// \brief Reads a csv file and loads the data into multiple tensors.
  /// \param[in] file The file to read.
  /// \param[in] start_offset The start offset of file.
  /// \param[in] end_offset The end offset of file.
  /// \param[in] worker_id The id of the worker that is executing this function.
  /// \return Status The error code returned.
  Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;

  std::string usage_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_
@@ -123,6 +123,7 @@ constexpr char kSemeionNode[] = "SemeionDataset";
constexpr char kSogouNewsNode[] = "SogouNewsDataset";
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
constexpr char kSQuADNode[] = "SQuADDataset";
constexpr char kSST2Node[] = "SST2Dataset";
constexpr char kSTL10Node[] = "STL10Dataset";
constexpr char kSUN397Node[] = "SUN397Dataset";
constexpr char kTedliumNode[] = "TedliumDataset";
@@ -50,6 +50,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
    sogou_news_node.cc
    speech_commands_node.cc
    squad_node.cc
    sst2_node.cc
    stl10_node.cc
    sun397_node.cc
    tedlium_node.cc
@@ -0,0 +1,204 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
SST2Node::SST2Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
                   int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
    : NonMappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {
  // Update the num_shards_ in the global context. This number is currently only used by auto_num_worker_pass; user
  // discretion is advised. Auto_num_worker_pass is an experimental feature that can still work if num_shards_ isn't
  // 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return num_shards. Once
  // PreBuildSampler is phased out, this can be cleaned up.
  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

std::shared_ptr<DatasetNode> SST2Node::Copy() {
  auto node = std::make_shared<SST2Node>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
  (void)node->SetNumWorkers(num_workers_);
  (void)node->SetConnectorQueueSize(connector_que_size_);
  return node;
}

void SST2Node::Print(std::ostream &out) const {
  out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
          ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}

Status SST2Node::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("SST2Node", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateStringValue("SST2Node", usage_, {"dev", "train", "test"}));
  RETURN_IF_NOT_OK(ValidateScalar("SST2Node", "num_samples", num_samples_, {0}, false));
  RETURN_IF_NOT_OK(ValidateDatasetShardParams("SST2Node", num_shards_, shard_id_));

  return Status::OK();
}

// Function to build SST2Node
Status SST2Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  RETURN_UNEXPECTED_IF_NULL(node_ops);
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

  // Sort the dataset files in a lexicographical order
  std::vector<std::string> sorted_dataset_files;
  RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &sorted_dataset_files));
  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());

  char field_delim = '\t';

  std::vector<std::string> column_names;

  std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;

  std::shared_ptr<SST2Op> sst2_op = std::make_shared<SST2Op>(
    sorted_dataset_files, usage_, field_delim, column_default_list, column_names, num_workers_, num_samples_,
    worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);

  RETURN_IF_NOT_OK(sst2_op->Init());

  // If a global shuffle is used for SST2, it will inject a shuffle op over the SST2.
  // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
  // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset SST2's shuffle
  // option to false.
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_IF_NOT_OK(SST2Op::CountAllFileRows(sorted_dataset_files, column_names.empty(), &num_rows));

    // Add the shuffle op after this op
    RETURN_IF_NOT_OK(
      AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
    shuffle_op->SetTotalRepeats(GetTotalRepeats());
    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
    node_ops->push_back(shuffle_op);
  }
  sst2_op->SetTotalRepeats(GetTotalRepeats());
  sst2_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(sst2_op);

  return Status::OK();
}

Status SST2Node::WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
                              std::vector<std::string> *dataset_files) {
  RETURN_UNEXPECTED_IF_NULL(dataset_files);
  Path train_file_name("train.tsv");
  Path test_file_name("test.tsv");
  Path dev_file_name("dev.tsv");
  Path dir(dataset_dir);
  // Each usage maps to exactly one .tsv file under dataset_dir.
  if (usage == "train") {
    Path file_path = dir / train_file_name;
    dataset_files->push_back(file_path.ToString());
  } else if (usage == "test") {
    Path file_path = dir / test_file_name;
    dataset_files->push_back(file_path.ToString());
  } else if (usage == "dev") {
    Path file_path = dir / dev_file_name;
    dataset_files->push_back(file_path.ToString());
  }
  return Status::OK();
}

// Get the shard id of node
Status SST2Node::GetShardId(int32_t *shard_id) {
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  *shard_id = shard_id_;
  return Status::OK();
}

// Get Dataset size
Status SST2Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  std::vector<std::string> dataset_files;

  RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &dataset_files));
  RETURN_IF_NOT_OK(SST2Op::CountAllFileRows(dataset_files, true, &num_rows));
  sample_size = num_samples_;
  // Rows are divided evenly across shards (rounding up), then capped by num_samples if it is set.
  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

Status SST2Node::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["num_parallel_workers"] = num_workers_;
  args["connector_queue_size"] = connector_que_size_;
  args["dataset_dir"] = dataset_dir_;
  args["usage"] = usage_;
  args["num_samples"] = num_samples_;
  args["shuffle"] = shuffle_;
  args["num_shards"] = num_shards_;
  args["shard_id"] = shard_id_;
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}

// Note: The following two functions are common among NonMappableSourceNode subclasses and should be promoted to the
// parent class. SST2 by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we set up the sampler for a leaf node that does not use sampling.
Status SST2Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
  return Status::OK();
}

// If a cache has been added into the ascendant tree over this SST2 node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the SST2 node need to be reset to their defaults so
// that this SST2 node will produce the full set of data into the cache.
Status SST2Node::MakeSimpleProducer() {
  shard_id_ = 0;
  num_shards_ = 1;
  shuffle_ = ShuffleMode::kFalse;
  num_samples_ = 0;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
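As a sanity check on the sharding arithmetic in `SST2Node::GetDatasetSize` above, here is a small Python model; the function name is invented, and the sample counts come from the documentation in this commit:

.. code-block:: python

    import math

    def sst2_dataset_size(num_rows: int, num_shards: int, num_samples: int) -> int:
        """Model of SST2Node::GetDatasetSize: rows are ceil-divided across shards,
        then capped by num_samples when num_samples > 0."""
        per_shard = math.ceil(num_rows / num_shards)
        return min(per_shard, num_samples) if num_samples > 0 else per_shard

    print(sst2_dataset_size(1821, 2, 0))   # 911 rows per shard for the test split
    print(sst2_dataset_size(67349, 1, 3))  # 3 when num_samples = 3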
@@ -0,0 +1,120 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/sst2_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
class SST2Node : public NonMappableSourceNode {
 public:
  /// \brief Constructor.
  SST2Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
           int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor.
  ~SST2Node() override = default;

  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kSST2Node; }

  /// \brief Print the description.
  /// \param[out] out The output stream to write output to.
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief A base class override function to create the required runtime dataset op objects for this class.
  /// \param[in] node_ops A vector containing shared pointers to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build successfully.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;

  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;

  /// \brief Generate a list of files to read according to usage.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Part of dataset of SST2.
  /// \param[out] dataset_files List of file paths for the dataset files.
  /// \return Status of the function.
  Status WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
                      std::vector<std::string> *dataset_files);

  /// \brief Get the shard id of node.
  /// \param[out] shard_id The shard id.
  /// \return Status Status::OK() if get shard id successfully.
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size The size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }
  const std::string &Usage() const { return usage_; }
  int64_t NumSamples() const { return num_samples_; }
  ShuffleMode Shuffle() const { return shuffle_; }
  int32_t NumShards() const { return num_shards_; }
  int32_t ShardId() const { return shard_id_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

  /// \brief SST2 by itself is a non-mappable dataset that does not support sampling.
  ///     However, if a cache operator is injected at some other place higher in the tree, that cache can
  ///     inherit this sampler from the leaf, providing sampling support from the caching layer.
  ///     That is why we set up the sampler for a leaf node that does not use sampling.
  ///     Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class.
  /// \param[in] sampler The sampler to set up.
  /// \return Status of the function.
  Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;

  /// \brief If a cache has been added into the ascendant tree over this SST2 node, then the cache will be
  ///     executing a sampler for fetching the data. As such, any options in the SST2 node need to be reset
  ///     to their defaults so that this SST2 node will produce the full set of data into the cache. Note:
  ///     This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class.
  /// \return Status of the function.
  Status MakeSimpleProducer() override;

 private:
  std::string dataset_dir_;
  std::string usage_;
  int64_t num_samples_;
  ShuffleMode shuffle_;
  int32_t num_shards_;
  int32_t shard_id_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
@@ -5237,6 +5237,64 @@ inline std::shared_ptr<SQuADDataset> DATASET_API SQuAD(const std::string &datase
                                       num_shards, shard_id, cache);
}

/// \class SST2Dataset
/// \brief A source dataset for reading and parsing the SST2 dataset.
class DATASET_API SST2Dataset : public Dataset {
 public:
  /// \brief Constructor of SST2Dataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Part of dataset of SST2, can be "train", "test" or "dev".
  /// \param[in] num_samples The number of samples to be included in the dataset.
  /// \param[in] shuffle The mode for shuffling data every epoch.
  ///     Can be any of:
  ///     ShuffleMode::kFalse - No shuffling is performed.
  ///     ShuffleMode::kFiles - Shuffle files only.
  ///     ShuffleMode::kGlobal - Shuffle both the files and samples.
  /// \param[in] num_shards Number of shards that the dataset should be divided into.
  /// \param[in] shard_id The shard ID within num_shards. This argument should be
  ///     specified only when num_shards is also specified.
  /// \param[in] cache Tensor cache to use.
  SST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
              ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor of SST2Dataset.
  ~SST2Dataset() override = default;
};

/// \brief Function to create an SST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of SST2, can be "train", "test" or "dev". Default: "train".
/// \param[in] num_samples The number of samples to be included in the dataset.
///     Default: 0, means all samples.
/// \param[in] shuffle The mode for shuffling data every epoch. Default: ShuffleMode::kGlobal.
///     Can be any of:
///     ShuffleMode::kFalse - No shuffling is performed.
///     ShuffleMode::kFiles - Shuffle files only.
///     ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. Default: 1.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
///     specified only when num_shards is also specified. Default: 0.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the SST2Dataset.
/// \par Example
/// \code
///      /* Define dataset path and MindData object */
///      std::string folder_path = "/path/to/sst2_dataset_directory";
///      std::shared_ptr<Dataset> ds = SST2(folder_path, "train");
///
///      /* Create iterator to read dataset */
///      std::shared_ptr<Iterator> iter = ds->CreateIterator();
///      std::unordered_map<std::string, mindspore::MSTensor> row;
///      iter->GetNextRow(&row);
/// \endcode
inline std::shared_ptr<SST2Dataset> DATASET_API SST2(const std::string &dataset_dir, const std::string &usage = "train",
                                                     int64_t num_samples = 0,
                                                     ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
                                                     int32_t shard_id = 0,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle, num_shards,
                                       shard_id, cache);
}

/// \class STL10Dataset
/// \brief A source dataset that reads and parses the STL10 dataset.
class DATASET_API STL10Dataset : public Dataset {
@@ -68,6 +68,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
  friend class SBUDataset;
  friend class SemeionDataset;
  friend class SpeechCommandsDataset;
  friend class SST2Dataset;
  friend class STL10Dataset;
  friend class SUN397Dataset;
  friend class TedliumDataset;
@@ -85,6 +85,7 @@ __all__ = ["Caltech101Dataset",  # Vision
           "PennTreebankDataset",  # Text
           "SogouNewsDataset",  # Text
           "SQuADDataset",  # Text
           "SST2Dataset",  # Text
           "TextFileDataset",  # Text
           "UDPOSDataset",  # Text
           "WikiTextDataset",  # Text
@@ -30,7 +30,8 @@ from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt
    check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
    check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
    check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
    check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset, check_squad_dataset
    check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset, check_squad_dataset, \
    check_sst2_dataset

from ..core.validator_helpers import replace_none
@@ -1496,6 +1497,105 @@ class SQuADDataset(SourceDataset, TextBaseDataset):
                            self.num_shards, self.shard_id)


class SST2Dataset(SourceDataset, TextBaseDataset):
    """
    A source dataset that reads and parses the SST2 dataset.

    The generated dataset's train.tsv and dev.tsv have two columns :py:obj:`[sentence, label]` .
    The generated dataset's test.tsv has one column :py:obj:`[sentence]` .
    The tensors of columns :py:obj:`sentence` and :py:obj:`label` are of the string type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be 'train', 'test' or 'dev'. 'train' will read
            from 67,349 train samples, 'test' will read from 1,821 test samples, 'dev' will read from
            all 872 samples. Default: None, will read train samples.
        num_samples (int, optional): The number of samples to be included in the dataset.
            Default: None, will include all samples.
        num_parallel_workers (int, optional): Number of workers to read the data.
            Default: None, number set in the mindspore.dataset.config.
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch.
            Bool type and Shuffle enum are both supported to pass in. Default: `Shuffle.GLOBAL` .
            If shuffle is False, no shuffling will be performed.
            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
            Set the mode of data shuffling by passing in enumeration variables:

            - Shuffle.GLOBAL: Shuffle the samples.

        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: None.
            When this argument is specified, `num_samples` reflects the maximum sample number per shard.
        shard_id (int, optional): The shard ID within `num_shards`. This argument can only be specified when
            `num_shards` is also specified. Default: None.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: None, which means no cache is used.

    Raises:
        RuntimeError: If `dataset_dir` does not contain data files.
        RuntimeError: If `num_parallel_workers` exceeds the max thread numbers.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).

    Examples:
        >>> sst2_dataset_dir = "/path/to/sst2_dataset_directory"
        >>>
        >>> # 1) Read 3 samples from SST2 dataset
        >>> dataset = ds.SST2Dataset(dataset_dir=sst2_dataset_dir, num_samples=3)
        >>>
        >>> # 2) Read train samples from SST2 dataset
        >>> dataset = ds.SST2Dataset(dataset_dir=sst2_dataset_dir, usage="train")

    About SST2 dataset:

    The Stanford Sentiment Treebank is a corpus with fully labeled parse trees that allows for a complete
    analysis of the compositional effects of sentiment in language. The corpus is based on the dataset introduced
    by Pang and Lee (2005) and consists of 11,855 single sentences extracted from movie reviews. It was parsed
    with the Stanford parser and includes a total of 215,154 unique phrases from those parse trees, each
    annotated by 3 human judges.

    Here is the original SST2 dataset structure.
    You can unzip the dataset files into this directory structure and read them with MindSpore's API.

    .. code-block::

        .
        └── sst2_dataset_dir
            ├── train.tsv
            ├── test.tsv
            ├── dev.tsv
            └── original

    Citation:

    .. code-block::

        @inproceedings{socher-etal-2013-recursive,
            title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
            author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
                      Christopher D. and Ng, Andrew and Potts, Christopher},
            booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
            month = oct,
            year = {2013},
            address = {Seattle, Washington, USA},
            publisher = {Association for Computational Linguistics},
            url = {https://www.aclweb.org/anthology/D13-1170},
            pages = {1631--1642},
        }
    """

    @check_sst2_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
                 num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir
        self.usage = replace_none(usage, "train")

    def parse(self, children=None):
        return cde.SST2Node(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag,
                            self.num_shards, self.shard_id)


class TextFileDataset(SourceDataset, TextBaseDataset):
    """
    A source dataset that reads and parses datasets stored on disk in text format.
@@ -2819,6 +2819,34 @@ def check_svhn_dataset(method):
    return new_method


def check_sst2_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SST2Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "dev"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_stl10_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(STL10Dataset)."""
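To see what `check_sst2_dataset` above enforces, here is a hypothetical session against the Python API (the path is a placeholder; per the class docstring, an unknown `usage` raises ValueError and a lone `shard_id` raises RuntimeError):

.. code-block:: python

    import mindspore.dataset as ds

    # Unknown usage string: check_valid_str only accepts "train", "test" and "dev".
    try:
        ds.SST2Dataset("/path/to/sst2_dataset_dir", usage="all")
    except ValueError as e:
        print("rejected:", e)

    # shard_id without num_shards violates check_sampler_shuffle_shard_options.
    try:
        ds.SST2Dataset("/path/to/sst2_dataset_dir", shard_id=0)
    except RuntimeError as e:
        print("rejected:", e)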
@@ -0,0 +1,524 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/include/dataset/datasets.h"

using namespace mindspore::dataset;

class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Feature: SST2Dataset
/// Description: Read test data
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetBasic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetBasic.";

  // Create a SST2 Dataset
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  // The index column of test.tsv is dropped, so only "sentence" should remain.
  EXPECT_EQ(row.find("index"), row.end());
  EXPECT_NE(row.find("sentence"), row.end());

  std::vector<std::vector<std::string>> expected_result = {
    {"test read SST2dataset 1 ."},
    {"test read SST2dataset 2 ."},
    {"test read SST2dataset 3 ."}};
  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();
}

/// Feature: SST2Dataset
/// Description: Read train data
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetUsageTrain) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetUsageTrain.";

  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence", "label"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());

  std::vector<std::vector<std::string>> expected_result = {
    {"train read SST2Dataset 1 . ", "0"},
    {"train read SST2Dataset 2 . ", "1"},
    {"train read SST2Dataset 3 . ", "1"},
    {"train read SST2Dataset 4 . ", "1"},
    {"train read SST2Dataset 5 . ", "0"}};
  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();
}

/// Feature: SST2Dataset
/// Description: Includes tests for shape, type, size
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetGetters.";

  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kFalse);
  std::vector<std::string> column_names = {"sentence"};
  EXPECT_NE(ds, nullptr);

  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  EXPECT_EQ(types.size(), 1);
  EXPECT_EQ(types[0].ToString(), "string");

  EXPECT_EQ(shapes.size(), 1);
  EXPECT_EQ(shapes[0].ToString(), "<>");

  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ds->GetColumnNames(), column_names);
}

/// Feature: SST2Dataset
/// Description: Read 2 samples from train file
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetNumSamples) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetNumSamples.";

  // Create a SST2Dataset.
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence", "label"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 2, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());
  EXPECT_NE(row.find("label"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"train read SST2Dataset 1 . ", "0"},
    {"train read SST2Dataset 2 . ", "1"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 2);

  // Manually terminate the pipeline
  iter->Stop();
}

/// Feature: SST2Dataset
/// Description: Test in a distributed state
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetDistribution) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetDistribution.";

  // Create a SST2Dataset
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence", "label"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 2, 0);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());
  EXPECT_NE(row.find("label"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"train read SST2Dataset 1 . ", "0"},
    {"train read SST2Dataset 2 . ", "1"},
    {"train read SST2Dataset 3 . ", "1"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();
}

/// Feature: SST2Dataset
/// Description: Test with invalid input
/// Expectation: Throw error messages when certain errors occur
TEST_F(MindDataTestPipeline, TestSST2DatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetFail.";
  // Create a SST2 Dataset
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::string invalid_folder_path = "./NotExistPath";
  std::vector<std::string> column_names = {"sentence", "label"};

  // Test invalid folder_path
  std::shared_ptr<Dataset> ds0 = SST2(invalid_folder_path, "train", -1, ShuffleMode::kFalse);
  EXPECT_NE(ds0, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
  // Expect failure: invalid SST2 input
  EXPECT_EQ(iter0, nullptr);

  // Test invalid usage
  std::shared_ptr<Dataset> ds1 = SST2(folder_path, "all", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds1, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  // Expect failure: invalid SST2 input
  EXPECT_EQ(iter1, nullptr);

  // Test invalid num_samples < 0
  std::shared_ptr<Dataset> ds2 = SST2(folder_path, "train", -1, ShuffleMode::kFalse);
  EXPECT_NE(ds2, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  // Expect failure: invalid SST2 input
  EXPECT_EQ(iter2, nullptr);

  // Test invalid num_shards < 1
  std::shared_ptr<Dataset> ds3 = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 0);
  EXPECT_NE(ds3, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
  // Expect failure: invalid SST2 input
  EXPECT_EQ(iter3, nullptr);

  // Test invalid shard_id >= num_shards
  std::shared_ptr<Dataset> ds4 = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 2, 2);
  EXPECT_NE(ds4, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
  // Expect failure: invalid SST2 input
  EXPECT_EQ(iter4, nullptr);
}
|
||||
|
||||
/// Feature: SST2Dataset
/// Description: Read data with pipeline from test file
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetWithPipeline.";

  // Create two SST2 Datasets, both reading the single SST2 test file
  std::string dataset_dir = datasets_root_path_ + "/testSST2/";

  std::shared_ptr<Dataset> ds1 = SST2(dataset_dir, "test", 0, ShuffleMode::kFalse);
  std::shared_ptr<Dataset> ds2 = SST2(dataset_dir, "test", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create two Repeat operations on ds1 and ds2
  int32_t repeat_num = 2;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 3;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operations on ds1 and ds2
  std::vector<std::string> column_project = {"sentence"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

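  // Both branches now expose only the "sentence" column, so their schemas match for Concat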
  // Create a Concat operation on the two datasets
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("sentence"), row.end());
  uint64_t i = 0;
  while (row.size() != 0) {
    auto text = row["sentence"];
    MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
    i++;
    ASSERT_OK(iter->GetNextRow(&row));
  }

  // Expect 2 * 3 + 3 * 3 = 15 samples: the 3 test samples repeated twice in ds1 and three times in ds2
  EXPECT_EQ(i, 15);

  // Manually terminate the pipeline
  iter->Stop();
}

/// Feature: SST2Dataset
/// Description: Test with shuffle files
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleFilesA) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleFilesA.";

  // Set configuration
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(130);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

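  // ShuffleMode::kFiles shuffles the file list only; "train" maps to a single file, so row order is preserved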
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence", "label"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFiles);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"train read SST2Dataset 1 . ", "0"},
    {"train read SST2Dataset 2 . ", "1"},
    {"train read SST2Dataset 3 . ", "1"},
    {"train read SST2Dataset 4 . ", "1"},
    {"train read SST2Dataset 5 . ", "0"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: SST2Dataset
/// Description: Test with in-file shuffle
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleFilesB) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleFilesB.";

  // Set configuration
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(130);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

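  // ShuffleMode::kInfile shuffles rows within each file; with seed 130 the three test rows come back in file order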
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kInfile);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"test read SST2dataset 1 ."},
    {"test read SST2dataset 2 ."},
    {"test read SST2dataset 3 ."}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 3 samples
  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: SST2Dataset
/// Description: Test with global shuffle
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleGlobal) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleGlobal.";

  // Set configuration
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(130);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

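  // ShuffleMode::kGlobal shuffles samples across the whole dataset; seed 130 yields the 1, 3, 2 order below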
  std::string folder_path = datasets_root_path_ + "/testSST2/";
  std::vector<std::string> column_names = {"sentence"};
  std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kGlobal);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("sentence"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"test read SST2dataset 1 ."},
    {"test read SST2dataset 3 ."},
    {"test read SST2dataset 2 ."}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (int j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::shared_ptr<Tensor> de_text;
      ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
      std::string_view sv;
      ASSERT_OK(de_text->GetItemAt(&sv, {}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 3 samples
  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

@ -0,0 +1,5 @@
sentence label
dev read SST2Dataset 1 . 0
dev read SST2Dataset 2 . 1
dev read SST2Dataset 3 . 1
dev read SST2Dataset 4 . 1

@ -0,0 +1,4 @@
index sentence
0 test read SST2dataset 1 .
1 test read SST2dataset 2 .
2 test read SST2dataset 3 .

@ -0,0 +1,6 @@
sentence label
train read SST2Dataset 1 . 0
train read SST2Dataset 2 . 1
train read SST2Dataset 3 . 1
train read SST2Dataset 4 . 1
train read SST2Dataset 5 . 0

@ -0,0 +1,136 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SST2 dataset operators
"""
import mindspore.dataset as ds

DATA_DIR = '../data/dataset/testSST2/'
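# The testSST2 data (see the TSV files above) holds 5 train, 3 test and 4 dev samples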


def test_sst2_dataset_basic():
    """
    Feature: SST2Dataset
    Description: Read data from train file
    Expectation: The data is processed successfully
    """
    buffer = []
    data = ds.SST2Dataset(DATA_DIR, usage="train", shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
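    # 5 train samples repeated twice gives 10 rows; skipping 3 leaves 7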
    assert len(buffer) == 7


def test_sst2_dataset_quoted():
    """
    Feature: SST2Dataset
    Description: Read the data and compare it to expectations
    Expectation: The data is processed successfully
    """
    data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['sentence']])
    assert buffer == ["test read SST2dataset 1 .",
                      "test read SST2dataset 2 .",
                      "test read SST2dataset 3 ."]


def test_sst2_dataset_usage():
    """
    Feature: SST2Dataset
    Description: Read samples with usage dev
    Expectation: The data is processed successfully
    """
    buffer = []
    data = ds.SST2Dataset(DATA_DIR, usage="dev", shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 4


def test_sst2_dataset_get_dataset_size():
    """
    Feature: SST2Dataset
    Description: Test get_dataset_size function
    Expectation: The data is processed successfully
    """
    data = ds.SST2Dataset(DATA_DIR, usage="dev", shuffle=False)
    size = data.get_dataset_size()
    assert size == 4


def test_sst2_dataset_distribution():
    """
    Feature: SST2Dataset
    Description: Test in a distributed state
    Expectation: The data is processed successfully
    """
    data = ds.SST2Dataset(DATA_DIR, usage="train", shuffle=False, num_shards=2, shard_id=0)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
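    # 5 train samples split over num_shards=2 leave 3 rows on shard_id=0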
    assert count == 3


def test_sst2_dataset_num_samples():
    """
    Feature: SST2Dataset
    Description: Test num_samples parameter
    Expectation: The data is processed successfully
    """
    data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
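    # num_samples=2 caps the 3 test samples at 2 rows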
    assert count == 2


def test_sst2_dataset_exception():
    """
    Feature: SST2Dataset
    Description: Test with a map function that raises an exception
    Expectation: Raise a RuntimeError when iterating the data
    """
    def exception_func(item):
        raise Exception("Error occur!")

    try:
        data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False)
        data = data.map(operations=exception_func, input_columns=["sentence"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)


if __name__ == "__main__":
|
||||
test_sst2_dataset_basic()
|
||||
test_sst2_dataset_quoted()
|
||||
test_sst2_dataset_usage()
|
||||
test_sst2_dataset_get_dataset_size()
|
||||
test_sst2_dataset_distribution()
|
||||
test_sst2_dataset_num_samples()
|
||||
test_sst2_dataset_exception()
|