!22085 [assistant][ops] add new data operator WikiTextDataset

Merge pull request !22085 from ZJUTER0126/WikiTextDataset
i-robot 2021-12-30 03:12:03 +00:00 committed by Gitee
commit 7c241bbaf5
18 changed files with 1650 additions and 4 deletions

View File

@@ -128,6 +128,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/wiki_text_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yahoo_answers_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yelp_review_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
@@ -1664,7 +1665,15 @@ VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<c
auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} // namespace dataset
}
WikiTextDataset::WikiTextDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<WikiTextNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
const std::vector<std::vector<char>> &columns_list,

View File

@@ -56,6 +56,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/wiki_text_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yahoo_answers_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yelp_review_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
@@ -614,6 +615,18 @@ PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
(void)py::class_<WikiTextNode, DatasetNode, std::shared_ptr<WikiTextNode>>(*m, "WikiTextNode",
"to create a WikiTextNode")
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) {
auto wiki_text = std::make_shared<WikiTextNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(wiki_text->ValidateParams());
return wiki_text;
}));
}));
PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
(void)py::class_<YahooAnswersNode, DatasetNode, std::shared_ptr<YahooAnswersNode>>(
*m, "YahooAnswersNode", "to create a YahooAnswersNode")
@@ -647,6 +660,5 @@ PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) {
return yes_no;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -42,6 +42,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
text_file_op.cc
udpos_op.cc
usps_op.cc
wiki_text_op.cc
yahoo_answers_op.cc
yelp_review_op.cc
yes_no_op.cc

View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/wiki_text_op.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
WikiTextOp::WikiTextOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, const std::vector<std::string> &file_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), file_list, op_connector_size,
shuffle_files, num_devices, device_id) {}
// A print method typically used for debugging.
void WikiTextOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nWikiText files list:\n";
for (size_t i = 0; i < text_files_list_.size(); ++i) {
out << " " << text_files_list_[i];
}
out << "\nData Schema:\n";
out << *data_schema_ << "\n\n";
}
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_WIKI_TEXT_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_WIKI_TEXT_OP_H_
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/util/queue.h"
namespace mindspore {
namespace dataset {
class JaggedConnector;
class WikiTextOp : public TextFileOp {
public:
/// \brief Constructor.
/// \param[in] num_workers Number of workers reading the data in parallel.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] worker_connector_size Size of each internal queue.
/// \param[in] data_schema The data schema object.
/// \param[in] file_list List of files to be read. The list will be sorted in a lexicographical order.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices Number of devices that the dataset should be divided into.
/// \param[in] device_id The device ID within num_devices. This argument should be
/// specified only when num_devices is also specified.
WikiTextOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, std::unique_ptr<DataSchema>,
const std::vector<std::string> &file_list, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id);
/// \brief Default destructor.
~WikiTextOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "WikiTextOp"; }
/// \brief DatasetName name getter.
/// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name.
/// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "WikiText" : "wiki text"; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_WIKI_TEXT_OP_H_

View File

@@ -119,6 +119,7 @@ constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUDPOSNode[] = "UDPOSDataset";
constexpr char kUSPSNode[] = "USPSDataset";
constexpr char kVOCNode[] = "VOCDataset";
constexpr char kWikiTextNode[] = "WikiTextDataset";
constexpr char kYahooAnswersNode[] = "YahooAnswersDataset";
constexpr char kYelpReviewNode[] = "YelpReviewDataset";
constexpr char kYesNoNode[] = "YesNoDataset";

View File

@@ -45,6 +45,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
udpos_node.cc
usps_node.cc
voc_node.cc
wiki_text_node.cc
yahoo_answers_node.cc
yelp_review_node.cc
yes_no_node.cc

View File

@@ -0,0 +1,198 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/wiki_text_node.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/wiki_text_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for WikiTextNode.
WikiTextNode::WikiTextNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
wikitext_files_list_(WalkAllFiles(usage, dataset_dir)) {
// Update the num_shards_ in global context. This number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> WikiTextNode::Copy() {
auto node =
std::make_shared<WikiTextNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
return node;
}
void WikiTextNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status WikiTextNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("WikiTextDataset", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("WikiTextDataset", usage_, {"train", "test", "valid", "all"}));
RETURN_IF_NOT_OK(ValidateScalar("WikiTextDataset", "num_samples", num_samples_, {0}, false));
RETURN_IF_NOT_OK(
ValidateEnum("WikiTextDataset", "ShuffleMode", shuffle_,
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal, ShuffleMode::kInfile}));
RETURN_IF_NOT_OK(ValidateDatasetShardParams("WikiTextDataset", num_shards_, shard_id_));
return Status::OK();
}
// Function to build WikiTextNode.
Status WikiTextNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order.
std::vector<std::string> sorted_dataset_files = wikitext_files_list_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// Create and initialize WikiTextOp.
std::shared_ptr<WikiTextOp> wikitext_op =
std::make_shared<WikiTextOp>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(wikitext_op->Init());
// If a global shuffle is used for WikiText, it will inject a shuffle op over the WikiText.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset WikiText's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp.
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset.
RETURN_IF_NOT_OK(WikiTextOp::CountAllFileRows(wikitext_files_list_, &num_rows));
// Add the shuffle op after this op.
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
wikitext_op->SetTotalRepeats(GetTotalRepeats());
wikitext_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add WikiTextOp.
node_ops->push_back(wikitext_op);
return Status::OK();
}
// Get the shard id of node.
Status WikiTextNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size.
Status WikiTextNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(WikiTextOp::CountAllFileRows(wikitext_files_list_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status WikiTextNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// WikiText by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
Status WikiTextNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this WikiText node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the WikiText node need to be reset to its defaults so
// that this WikiText node will produce the full set of data into the cache.
Status WikiTextNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
std::vector<std::string> WikiTextNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
std::vector<std::string> wikitext_files_list;
Path train_prefix("wiki.train.tokens");
Path test_prefix("wiki.test.tokens");
Path valid_prefix("wiki.valid.tokens");
Path dir(dataset_dir);
if (usage == "train") {
Path temp_path = dir / train_prefix;
wikitext_files_list.push_back(temp_path.ToString());
} else if (usage == "test") {
Path temp_path = dir / test_prefix;
wikitext_files_list.push_back(temp_path.ToString());
} else if (usage == "valid") {
Path temp_path = dir / valid_prefix;
wikitext_files_list.push_back(temp_path.ToString());
} else {
Path temp_path = dir / train_prefix;
wikitext_files_list.push_back(temp_path.ToString());
Path temp_path1 = dir / test_prefix;
wikitext_files_list.push_back(temp_path1.ToString());
Path temp_path2 = dir / valid_prefix;
wikitext_files_list.push_back(temp_path2.ToString());
}
return wikitext_files_list;
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_WIKI_TEXT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_WIKI_TEXT_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class WikiTextNode
/// \brief A Dataset derived class to represent the WikiText dataset.
class WikiTextNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
WikiTextNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~WikiTextNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kWikiTextNode; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id of the node.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Get the number of samples of the node.
/// \return NumSamples of the node.
int32_t NumSamples() const { return num_samples_; }
/// \brief Get the number of shards of the node.
/// \return NumShards of the node.
int32_t NumShards() const { return num_shards_; }
/// \brief Get the shard id of node.
/// \return Shard_id of the node.
int32_t ShardId() const { return shard_id_; }
/// \brief Get the shuffle mode of node.
/// \return Shuffle of the node.
ShuffleMode Shuffle() const { return shuffle_; }
/// \brief Get the usage node.
/// \return Usage of the node.
const std::string &Usage() const { return usage_; }
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief WikiText by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in
/// the tree, that cache can inherit this sampler from the leaf, providing
/// sampling support from the caching layer. That is why we setup the
/// sampler for a leaf node that does not use sampling. Note: This
/// function is common among NonMappableSourceNode and should be promoted
/// to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this WikiText node,
/// then the cache will be executing a sampler for fetching the data.
/// As such, any options in the WikiText node need to be reset to its defaults
/// so that this WikiText node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its
/// parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
/// \brief Generate a list of read file names according to usage.
/// \param[in] usage The part of the WikiText dataset to be read.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \return std::vector<std::string> A list of read file names.
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
private:
std::string dataset_dir_;
std::string usage_;
int64_t num_samples_;
int32_t num_shards_;
int32_t shard_id_;
ShuffleMode shuffle_;
std::vector<std::string> wikitext_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_WIKI_TEXT_NODE_H_

View File

@@ -4837,6 +4837,70 @@ inline std::shared_ptr<VOCDataset> MS_API VOC(const std::string &dataset_dir, co
MapStringToChar(class_indexing), decode, sampler, cache, extra_metadata);
}
/// \class WikiTextDataset
/// \brief A source dataset for reading and parsing the WikiText dataset.
class MS_API WikiTextDataset : public Dataset {
public:
/// \brief Constructor of WikiTextDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data list txt file to be read, can be "train", "test", "valid" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
explicit WikiTextDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache);
/// Destructor of WikiTextDataset.
~WikiTextDataset() = default;
};
/// \brief Function to create a WikiText Dataset.
/// \note The generated dataset has one column ['text'].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage One of "all", "train", "valid" or "test" (default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0, means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the WikiTextDataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string folder_path = "/path/to/wiki_dataset_directory";
/// std::shared_ptr<Dataset> ds = WikiText(folder_path, "all");
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In WikiText dataset, each dictionary has key "text" */
/// auto text = row["text"];
/// \endcode
inline std::shared_ptr<WikiTextDataset> MS_API WikiText(const std::string &dataset_dir,
const std::string &usage = "all", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<WikiTextDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
num_shards, shard_id, cache);
}
/// \class YahooAnswersDataset
/// \brief A source dataset for reading and parsing YahooAnswers dataset.
class MS_API YahooAnswersDataset : public Dataset {

View File

@@ -76,7 +76,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \
check_conll2000_dataset, check_amazon_review_dataset, check_semeion_dataset, check_caltech101_dataset, \
check_caltech256_dataset
check_caltech256_dataset, check_wiki_text_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -3502,7 +3502,6 @@ class FashionMnistDataset(MappableDataset):
We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking
machine learning algorithms. It shares the same image size and structure of training and testing splits.
Here is the original Fashion-MNIST dataset structure.
You can unzip the dataset files into this directory structure and read by MindSpore's API.
.. code-block::
@@ -6388,6 +6387,84 @@ class USPSDataset(SourceDataset):
self.shard_id)
class WikiTextDataset(SourceDataset):
"""
A source dataset that reads and parses WikiText2 and WikiText103 datasets.
The generated dataset has one column :py:obj:`[text]`.
The tensor of column :py:obj:`text` is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Acceptable usages include 'train', 'test', 'valid' and 'all' (default=None, all samples).
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, 'num_samples' reflects the maximum sample number per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Examples:
>>> wiki_text_dataset_dir = "/path/to/wiki_text_dataset_directory"
>>> dataset = ds.WikiTextDataset(dataset_dir=wiki_text_dataset_dir, usage='all')
About WikiText dataset:
The WikiText Long Term Dependency Language Modeling Dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia, released as WikiText2 and
WikiText103. For WikiText2, wiki.train.tokens has 36718 lines, wiki.test.tokens has 4358 lines and
wiki.valid.tokens has 3760 lines. For WikiText103, wiki.train.tokens has 1801350 lines, wiki.test.tokens
has 4358 lines and wiki.valid.tokens has 3760 lines.
Here is the original WikiText dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
WikiText2/WikiText103
wiki.train.tokens
wiki.test.tokens
wiki.valid.tokens
Citation:
.. code-block::
@article{merity2016pointer,
title={Pointer sentinel mixture models},
author={Merity, Stephen and Xiong, Caiming and Bradbury, James and Socher, Richard},
journal={arXiv preprint arXiv:1609.07843},
year={2016}
}
"""
@check_wiki_text_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.WikiTextNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
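For illustration, here is a minimal usage sketch of the new Python API (the directory path is hypothetical and assumed to contain the wiki.*.tokens files): read the train split as shard 0 of 2 with file-level shuffling and count the rows.

import mindspore.dataset as ds

wiki_dir = "/path/to/WikiText2"  # hypothetical directory holding wiki.train/test/valid.tokens
data = ds.WikiTextDataset(wiki_dir, usage="train", shuffle=ds.Shuffle.FILES,
                          num_shards=2, shard_id=0)
# Count the rows that land in shard 0.
count = sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
print("rows in shard 0:", count)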
class VOCDataset(MappableDataset):
"""
A source dataset for reading and parsing VOC dataset.

View File

@@ -2411,3 +2411,32 @@ def check_semeion_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_wiki_text_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(WikiTextDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_dir
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
# check usage
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "valid", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
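A small sketch of how this checker surfaces bad arguments (wiki_dir is hypothetical and assumed to point at an existing dataset directory): an out-of-range shard_id is rejected when the dataset object is constructed, before any pipeline runs.

import pytest
import mindspore.dataset as ds

def test_wiki_text_invalid_shard_id():
    wiki_dir = "/path/to/WikiText2"  # hypothetical, assumed to exist on disk
    # shard_id must lie in [0, num_shards); 2 is out of range for num_shards=2.
    with pytest.raises(ValueError):
        ds.WikiTextDataset(wiki_dir, usage="train", num_shards=2, shard_id=2)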

View File

@@ -58,6 +58,7 @@ SET(DE_UT_SRCS
c_api_dataset_yahoo_answers_test.cc
c_api_dataset_yelp_review_test.cc
c_api_dataset_yes_no_test.cc
c_api_dataset_wiki_text_test.cc
c_api_datasets_test.cc
c_api_epoch_ctrl_test.cc
c_api_pull_based_test.cc

View File

@@ -0,0 +1,583 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::ShuffleMode;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test WikiText Dataset.
/// Description: read WikiText data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetBasic.";
// Test WikiText Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" I am happy "},
{" finish math homework "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: read WikiText data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetBasicWithPipeline.";
// Test WikiText Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds1 = WikiText(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = WikiText(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 15 samples
EXPECT_EQ(i, 15);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: read WikiText data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextGetters.";
// Test WikiText Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::vector<std::string> column_names = {"text"};
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
ds = WikiText(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 3);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 1);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].ToString(), "<>");
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail1.";
// Create a WikiText Dataset
// with invalid num_samples=-1
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", -1, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: WikiText number of samples cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail2.";
// Attempt to create a WikiText Dataset
// with an invalid dataset_dir input
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText("123", "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: specified dataset_dir does not exist
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail3.";
// Create a WikiText Dataset
// with invalid usage input
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "asd", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid usage
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail4.";
// Create a WikiText Dataset
// with empty string dataset_dir input
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText("", "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: dataset_dir is not specified
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail5) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail5.";
// Create a WikiText Dataset
// with invalid num_shards=0 value
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", 2, ShuffleMode::kFalse, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Number of shards cannot be <=0
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail6) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail6.";
// Create a WikiText Dataset
// with invalid shard_id=-1 value
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", 2, ShuffleMode::kFalse, 1, -1);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: shard_id cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetFail7) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetFail7.";
// Create a WikiText Dataset
// with invalid shard_id=2 and num_shards=2 combination
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Cannot have shard_id >= num_shards
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test WikiText Dataset.
/// Description: Test WikiTextDataset with Shuffle mode False.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetShuffleFalse) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetShuffleFalse.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(2);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "all", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" go to china "},
{" I am happy "},
{" I lova MindSpore "},
{" finish math homework "},
{" black white grapes "},
{" just ahead of them there was a huge fissure "},
{" zhejiang, china "},
{" MindSpore Ascend "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: Test WikiTextDataset with Shuffle mode Files.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetShuffleFilesA.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(654);
GlobalContext::config_manager()->set_num_parallel_workers(1);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "all", 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" go to china "},
{" I lova MindSpore "},
{" black white grapes "},
{" no it was black friday "},
{" I am happy "},
{" finish math homework "},
{" just ahead of them there was a huge fissure "},
{" zhejiang, china "},
{" MindSpore Ascend "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: Test WikiTextDataset with Shuffle mode Infile.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetShuffleFilesB.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "all", 0, ShuffleMode::kInfile);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" go to china "},
{" just ahead of them there was a huge fissure "},
{" I am happy "},
{" I lova MindSpore "},
{" zhejiang, china "},
{" finish math homework "},
{" black white grapes "},
{" MindSpore Ascend "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test WikiText Dataset.
/// Description: Test WikiTextDataset with Shuffle mode Global.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestWikiTextDatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWikiTextDatasetShuffleGlobal.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create a WikiText Dataset with all three token files
// Note: wiki.train.tokens, wiki.test.tokens and wiki.valid.tokens each have 3 rows
// Set shuffle to global shuffle
std::string dataset_dir = datasets_root_path_ + "/testWikiText";
std::shared_ptr<Dataset> ds = WikiText(dataset_dir, "all", 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" MindSpore Ascend "},
{" go to china "},
{" I am happy "},
{" no it was black friday "},
{" just ahead of them there was a huge fissure "},
{" zhejiang, china "},
{" finish math homework "},
{" I lova MindSpore "},
{" black white grapes "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@@ -0,0 +1,3 @@
no it was black friday
I am happy
finish math homework

View File

@@ -0,0 +1,3 @@
go to china
I lova MindSpore
black white grapes

View File

@@ -0,0 +1,3 @@
just ahead of them there was a huge fissure
zhejiang, china
MindSpore Ascend

View File

@@ -0,0 +1,393 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed
FILE_DIR = '../data/dataset/testWikiText'
def test_wiki_text_dataset_test():
"""
Feature: Test WikiText Dataset.
Description: read test data from a single file.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='test', shuffle=False)
count = 0
test_content = [" no it was black friday ", " I am happy ", " finish math homework "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
strs = i["text"].item().decode("utf8")
assert strs == test_content[count]
count += 1
assert count == 3
def test_wiki_text_dataset_train():
"""
Feature: Test WikiText Dataset.
Description: read train data from a single file.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='train', shuffle=False)
count = 0
train_content = [" go to china ", " I lova MindSpore ", " black white grapes "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
strs = i["text"].item().decode("utf8")
assert strs == train_content[count]
count += 1
assert count == 3
def test_wiki_text_dataset_valid():
"""
Feature: Test WikiText Dataset.
Description: read valid data from a single file.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='valid', shuffle=False)
count = 0
valid_content = [" just ahead of them there was a huge fissure ", " zhejiang, china ", " MindSpore Ascend "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
strs = i["text"].item().decode("utf8")
assert strs == valid_content[count]
count += 1
assert count == 3
def test_wiki_text_dataset_all_file():
"""
Feature: Test WikiText Dataset.
Description: read data from all files.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='all')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 9
def test_wiki_text_dataset_num_samples_none():
"""
Feature: Test WikiText Dataset.
Description: read data with no num_samples input.
Expectation: the data is processed successfully.
"""
# Do not provide a num_samples argument, so it would be None by default, which means all samples are read.
data = ds.WikiTextDataset(FILE_DIR, usage='all')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 9


def test_wiki_text_dataset_shuffle_false_and_workers_4():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=False and num_parallel_workers=4.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(987)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=False)
count = 0
line = [" no it was black friday ",
" go to china ",
" just ahead of them there was a huge fissure ",
" I am happy ",
" I lova MindSpore ",
" zhejiang, china ",
" finish math homework ",
" black white grapes ",
" MindSpore Ascend "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_shuffle_false_and_workers_1():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=False and num_parallel_workers=1.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=False)
count = 0
line = [" no it was black friday ",
" I am happy ",
" finish math homework ",
" go to china ",
" I lova MindSpore ",
" black white grapes ",
" just ahead of them there was a huge fissure ",
" zhejiang, china ",
" MindSpore Ascend "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_shuffle_files_and_workers_4():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=Shuffle.FILES and num_parallel_workers=4.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(135)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
count = 0
line = [" just ahead of them there was a huge fissure ",
" go to china ",
" no it was black friday ",
" zhejiang, china ",
" I lova MindSpore ",
" I am happy ",
" MindSpore Ascend ",
" black white grapes ",
" finish math homework "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_shuffle_files_and_workers_1():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=Shuffle.FILES and num_parallel_workers=1.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(135)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
count = 0
line = [" just ahead of them there was a huge fissure ",
" zhejiang, china ",
" MindSpore Ascend ",
" go to china ",
" I lova MindSpore ",
" black white grapes ",
" no it was black friday ",
" I am happy ",
" finish math homework "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_shuffle_global4():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=Shuffle.GLOBAL and num_parallel_workers=4.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(246)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
count = 0
line = [" MindSpore Ascend ",
" go to china ",
" I am happy ",
" no it was black friday ",
" just ahead of them there was a huge fissure ",
" zhejiang, china ",
" finish math homework ",
" I lova MindSpore ",
" black white grapes "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_shuffle_global1():
"""
Feature: Test WikiText Dataset.
Description: Read data from all files with shuffle=Shuffle.GLOBAL and num_parallel_workers=1.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(246)
data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
count = 0
line = [" MindSpore Ascend ",
" go to china ",
" I am happy ",
" I lova MindSpore ",
" black white grapes ",
" finish math homework ",
" zhejiang, china ",
" no it was black friday ",
" just ahead of them there was a huge fissure "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)


def test_wiki_text_dataset_num_samples():
"""
Feature: Test WikiText Dataset.
Description: Test num_samples.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='all', num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2


def test_wiki_text_dataset_distribution():
"""
Feature: Test WikiText Dataset.
Description: Read data with num_shards=2 and shard_id=1.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1)
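# 9 rows split across num_shards=2 gives ceil(9 / 2) = 5 rows per shard (assumption: the sharded
# reader balances the shards, so shard_id=1 also receives 5 rows), hence the count checked below.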
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 5


def test_wiki_text_dataset_repeat():
"""
Feature: Test WikiText Dataset.
Description: Test repeat.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='test', shuffle=False)
data = data.repeat(3)
count = 0
line = [" no it was black friday ",
" I am happy ",
" finish math homework ",
" no it was black friday ",
" I am happy ",
" finish math homework ",
" no it was black friday ",
" I am happy ",
" finish math homework ",]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9


def test_wiki_text_dataset_get_datasetsize():
"""
Feature: Test WikiText Dataset.
Description: Test get_dataset_size.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='test')
size = data.get_dataset_size()
assert size == 3


def test_wiki_text_dataset_to_device():
"""
Feature: Test WikiText Dataset.
Description: Test to_device.
Expectation: the data is processed successfully.
"""
data = ds.WikiTextDataset(FILE_DIR, usage='test')
data = data.to_device()
data.send()


def test_wiki_text_dataset_exceptions():
"""
Feature: Test WikiText Dataset.
Description: Test exceptions.
Expectation: Correct exceptions are thrown and caught.
"""
with pytest.raises(ValueError) as error_info:
_ = ds.WikiTextDataset(FILE_DIR, usage='test', num_samples=-1)
assert "num_samples exceeds the boundary" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.WikiTextDataset("does/not/exist/no.txt")
assert str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.WikiTextDataset("")
assert str(error_info.value)
def exception_func(item):
raise Exception("Error occur!")
with pytest.raises(RuntimeError) as error_info:
data = ds.WikiTextDataset(FILE_DIR)
data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
if __name__ == "__main__":
test_wiki_text_dataset_test()
test_wiki_text_dataset_train()
test_wiki_text_dataset_valid()
test_wiki_text_dataset_all_file()
test_wiki_text_dataset_num_samples_none()
test_wiki_text_dataset_shuffle_false_and_workers_4()
test_wiki_text_dataset_shuffle_false_and_workers_1()
test_wiki_text_dataset_shuffle_files_and_workers_4()
test_wiki_text_dataset_shuffle_files_and_workers_1()
test_wiki_text_dataset_shuffle_global4()
test_wiki_text_dataset_shuffle_global1()
test_wiki_text_dataset_num_samples()
test_wiki_text_dataset_distribution()
test_wiki_text_dataset_repeat()
test_wiki_text_dataset_get_datasetsize()
test_wiki_text_dataset_to_device()
test_wiki_text_dataset_exceptions()
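
The snippet below is a minimal usage sketch of the new WikiTextDataset API and is not part of the change set itself. It only uses parameters already exercised by the tests above; the dataset directory path is hypothetical and stands in for a real WikiText corpus laid out like the test data.

import mindspore.dataset as ds

DATA_DIR = "/path/to/wikitext"  # hypothetical path; replace with a real WikiText directory

# Read the training split, globally shuffled, split across 2 shards and keep shard 0.
dataset = ds.WikiTextDataset(DATA_DIR, usage="train", shuffle=ds.Shuffle.GLOBAL,
                             num_shards=2, shard_id=0)

print("rows in this shard:", dataset.get_dataset_size())
for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
    text = row["text"].item().decode("utf8")  # each row holds a single line of raw text
    print(text)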