!22348 [assistant][ops] Add new loader EnWik9Dataset

Merge pull request !22348 from 杨旭华/EnWik9Dataset
i-robot 2022-01-05 07:36:48 +00:00 committed by Gitee
commit a1720c210d
18 changed files with 1412 additions and 3 deletions

View File

@@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
@@ -1151,6 +1152,12 @@ EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::ve
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
EnWik9Dataset::EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<EnWik9Node>(CharToString(dataset_dir), num_samples, shuffle, num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
int32_t base_seed, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {

View File

@@ -39,6 +39,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
@@ -248,6 +249,18 @@ PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
(void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node",
"to create an EnWik9Node")
.def(py::init([](std::string dataset_dir, int32_t num_samples, int32_t shuffle, int32_t num_shards,
int32_t shard_id) {
std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>(
dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(en_wik9->ValidateParams());
return en_wik9;
}));
}));
PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
(void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>(
*m, "FakeImageNode", "to create a FakeImageNode")

View File

@@ -17,6 +17,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
dbpedia_op.cc
div2k_op.cc
emnist_op.cc
en_wik9_op.cc
fake_image_op.cc
fashion_mnist_op.cc
flickr_op.cc

View File

@@ -0,0 +1,118 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/en_wik9_op.h"
#include <fstream>
#include <utility>
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
EnWik9Op::EnWik9Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> data_schema, const std::vector<std::string> &file_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: TextFileOp(num_workers, total_rows, worker_connector_size, std::move(data_schema), file_list, op_connector_size,
shuffle_files, num_devices, device_id) {}
// A print method typically used for debugging.
void EnWik9Op::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nEnWik9 file path:\n";
for (size_t i = 0; i < text_files_list_.size(); ++i) {
// Print each file path.
out << " " << text_files_list_[i];
}
out << "\nData Schema:\n";
out << *data_schema_ << "\n\n";
}
}
Status EnWik9Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
auto realpath = FileUtils::GetRealPath(file.data());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
}
std::ifstream handle(realpath.value());
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file +
". Check if the file is damaged or permission denied.");
}
int64_t rows_total = 0;
std::string line;
while (getline(handle, line)) {
// Note: empty lines are kept; each line read becomes one row.
// If we have read up to the end offset of this file, break.
if (rows_total >= end_offset) {
break;
}
// Skip lines before the start offset.
if (rows_total < start_offset) {
rows_total++;
continue;
}
TensorRow tRow(1, nullptr);
tRow.setPath({file});
RETURN_IF_NOT_OK(LoadTensor(line, &tRow));
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
rows_total++;
}
return Status::OK();
}
int64_t EnWik9Op::CountTotalRows(const std::string &file) {
auto realpath = FileUtils::GetRealPath(file.data());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file, " << file << " does not exist.";
return 0;
}
std::ifstream handle(realpath.value());
if (!handle.is_open()) {
MS_LOG(ERROR) << "Invalid file, failed to open file: " << file
<< ". Check if the file is damaged or permission denied.";
return 0;
}
std::string line;
int64_t count = 0;
while (getline(handle, line)) {
count++;
}
return count;
}
} // namespace dataset
} // namespace mindspore
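The start/end offset window in LoadFile above is how the NonMappableLeafOp machinery splits one file across shards and workers. A minimal standalone sketch of the same skip/break logic (the function name is hypothetical, not part of this patch):

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Sketch: collect only the rows in [start_offset, end_offset) from a text file,
// mirroring the skip/break window in EnWik9Op::LoadFile.
std::vector<std::string> ReadRowWindow(const std::string &path, int64_t start_offset, int64_t end_offset) {
  std::vector<std::string> rows;
  std::ifstream handle(path);
  std::string line;
  int64_t row = 0;
  while (std::getline(handle, line)) {
    if (row >= end_offset) {
      break;  // Past the window: stop reading.
    }
    if (row < start_offset) {
      row++;  // Before the window: count the line but do not load it.
      continue;
    }
    rows.push_back(line);
    row++;
  }
  return rows;
}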

View File

@@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
namespace mindspore {
namespace dataset {
class EnWik9Op : public TextFileOp {
public:
/// \brief Constructor.
/// \param[in] num_workers The number of worker threads reading data from enwik9 files.
/// \param[in] total_rows The number of rows to read.
/// \param[in] worker_connector_size Size of each internal queue.
/// \param[in] data_schema The data schema object.
/// \param[in] file_list List of file paths for the dataset files.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices The number of devices.
/// \param[in] device_id Id of device.
EnWik9Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> data_schema, const std::vector<std::string> &file_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
/// \brief Default destructor.
~EnWik9Op() = default;
/// \brief A print method typically used for debugging.
/// \param[out] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "EnWik9Op"; }
/// \brief Dataset name getter.
/// \param[in] upper A bool to control whether the dataset name is returned in upper case.
/// \return Dataset name of the current Op.
virtual std::string DatasetName(bool upper = false) const { return upper ? "EnWik9" : "enwik9"; }
/// \brief Reads a text file and loads the data into multiple TensorRows.
/// \param[in] file The file to read.
/// \param[in] start_offset The start offset of the file.
/// \param[in] end_offset The end offset of the file.
/// \param[in] worker_id The id of the worker that is executing this function.
/// \return Status The error code returned.
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
private:
/// \brief Count number of rows in each file.
/// \param[in] file Txt file name.
/// \return int64_t The total number of rows in file.
int64_t CountTotalRows(const std::string &file);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_

View File

@@ -248,6 +248,5 @@ Status TextFileOp::ComputeColMap() {
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -114,7 +114,7 @@ class TextFileOp : public NonMappableLeafOp {
// Count number of rows in each file.
// @param file - txt file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
virtual int64_t CountTotalRows(const std::string &file);
std::vector<std::string> text_files_list_;
std::unique_ptr<DataSchema> data_schema_;
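The one-line change above, making CountTotalRows virtual, is what allows EnWik9Op to substitute its own row counting while TextFileOp keeps driving the work through the base-class interface. A minimal sketch of why the virtual keyword matters here, with hypothetical stand-in classes:

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

// Sketch: the base class drives the work and dispatches to whichever
// Count override the concrete object provides.
struct BaseFileReader {
  virtual ~BaseFileReader() = default;
  virtual int64_t Count(const std::string &file) { return 0; }
  int64_t Total(const std::string &file) { return Count(file); }  // calls the override
};

struct EnWik9LikeReader : BaseFileReader {
  int64_t Count(const std::string &file) override { return 13; }  // subclass-specific counting
};

int main() {
  std::unique_ptr<BaseFileReader> op = std::make_unique<EnWik9LikeReader>();
  std::cout << op->Total("enwik9") << std::endl;  // prints 13; without 'virtual' it would print 0
  return 0;
}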

View File

@@ -91,6 +91,7 @@ constexpr char kCSVNode[] = "CSVDataset";
constexpr char kDBpediaNode[] = "DBpediaDataset";
constexpr char kDIV2KNode[] = "DIV2KDataset";
constexpr char kEMnistNode[] = "EMnistDataset";
constexpr char kEnWik9Node[] = "EnWik9Dataset";
constexpr char kFakeImageNode[] = "FakeImageDataset";
constexpr char kFashionMnistNode[] = "FashionMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";

View File

@@ -18,6 +18,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
dbpedia_node.cc
div2k_node.cc
emnist_node.cc
en_wik9_node.cc
fake_image_node.cc
fashion_mnist_node.cc
flickr_node.cc

View File

@@ -0,0 +1,174 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/en_wik9_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for EnWik9Node
EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: NonMappableSourceNode(std::move(cache)),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
dataset_dir_(dataset_dir) {
// Update the num_shards_ in global context. This number is currently used only by auto_num_worker_pass. User
// discretion is advised. Auto_num_worker_pass is an experimental feature which can still work even if num_shards_
// isn't 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
DirToPath(dataset_dir_);
}
std::shared_ptr<DatasetNode> EnWik9Node::Copy() {
auto node = std::make_shared<EnWik9Node>(dataset_dir_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
return node;
}
void EnWik9Node::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status EnWik9Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("EnWik9Dataset", dataset_dir_));
RETURN_IF_NOT_OK(ValidateEnum("EnWik9Dataset", "ShuffleMode", shuffle_,
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
RETURN_IF_NOT_OK(ValidateScalar("EnWik9Dataset", "num_samples", num_samples_, {0}, false));
RETURN_IF_NOT_OK(ValidateDatasetShardParams("EnWik9Dataset", num_shards_, shard_id_));
return Status::OK();
}
void EnWik9Node::DirToPath(const std::string &dataset_dir) {
Path train_prefix("enwik9");
Path dir(dataset_dir);
Path temp_path = dir / train_prefix;
src_target_file_list_.push_back(temp_path.ToString());
}
// Function to build EnWik9Node
Status EnWik9Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// Create and initialize EnWik9Op
std::shared_ptr<EnWik9Op> en_wik9_op =
std::make_shared<EnWik9Op>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
src_target_file_list_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(en_wik9_op->Init());
// If a global shuffle is used for EnWik9, it will inject a shuffle op over the EnWik9.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset EnWik9's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows));
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(src_target_file_list_.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
en_wik9_op->SetTotalRepeats(GetTotalRepeats());
en_wik9_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add EnWik9Op
node_ops->push_back(en_wik9_op);
return Status::OK();
}
// Get the shard id of node
Status EnWik9Node::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size
Status EnWik9Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status EnWik9Node::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// EnWik9 by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
Status EnWik9Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this EnWik9 node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the EnWik9 node need to be reset to its defaults so
// that this EnWik9 node will produce the full set of data into the cache.
Status EnWik9Node::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
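As a concrete check of the GetDatasetSize arithmetic above against the 13-row test file added in this PR: with num_shards = 2, ceil(13 / 2) = 7 rows per shard, which is what the two-shard tests below expect; with num_samples = 2 the result is min(13, 2) = 2. A minimal sketch of that rule (the helper name is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Sketch of the per-shard size rule in EnWik9Node::GetDatasetSize.
int64_t ShardedSize(int64_t num_rows, int32_t num_shards, int64_t num_samples) {
  int64_t per_shard = static_cast<int64_t>(std::ceil(num_rows / (1.0 * num_shards)));
  return num_samples > 0 ? std::min(per_shard, num_samples) : per_shard;
}

int main() {
  std::cout << ShardedSize(13, 2, 0) << std::endl;  // 7, matching the two-shard tests
  std::cout << ShardedSize(13, 1, 2) << std::endl;  // 2, matching the num_samples=2 tests
  return 0;
}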

View File

@@ -0,0 +1,136 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class EnWik9Node
/// \brief A Dataset derived class to represent the EnWik9 dataset.
class EnWik9Node : public NonMappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] dataset_dir The directory of the dataset.
/// \param[in] num_samples The number of samples to read.
/// \param[in] shuffle The shuffle mode for the dataset.
/// \param[in] num_shards The number of shards that the dataset will be divided into.
/// \param[in] shard_id The id of the shard.
/// \param[in] cache Tensor cache to use.
EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~EnWik9Node() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kEnWik9Node; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[in] shard_id Id of this shard.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
/// \return Directory of dataset.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Getter functions.
/// \return The number of samples.
int32_t NumSamples() const { return num_samples_; }
/// \brief Getter functions.
/// \return The number of shards.
int32_t NumShards() const { return num_shards_; }
/// \brief Getter functions.
/// \return Id of shard.
int32_t ShardId() const { return shard_id_; }
/// \brief Getter functions.
/// \return Shuffle pattern.
ShuffleMode Shuffle() const { return shuffle_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief EnWik9 by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we setup the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this EnWik9 node, then the cache will be executing.
/// a sampler for fetching the data. As such, any options in the EnWik9 node need to be reset to its defaults
/// so that this EnWik9 node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
/// \brief Build the path of the enwik9 file under the dataset directory and append it to the file list.
/// \param[in] dataset_dir Directory of the enwik9 dataset.
void DirToPath(const std::string &dataset_dir);
private:
std::string dataset_dir_; // Directory of the dataset.
int32_t num_samples_; // The number of samples.
int32_t num_shards_; // The number of shards.
int32_t shard_id_; // The id of the shard.
ShuffleMode shuffle_; // The shuffle mode (a ShuffleMode enum value).
std::vector<std::string> src_target_file_list_; // List of source files to read.
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_

View File

@@ -2342,6 +2342,51 @@ inline std::shared_ptr<EMnistDataset> MS_API EMnist(const std::string &dataset_d
cache);
}
/// \class EnWik9Dataset
/// \brief A source dataset for reading and parsing the EnWik9 dataset.
class MS_API EnWik9Dataset : public Dataset {
public:
/// \brief Function to create an EnWik9Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of EnWik9Dataset.
~EnWik9Dataset() = default;
};
/// \brief Function to create an EnWik9Dataset.
/// \note The generated dataset has one column ['text'].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0, which means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the EnWik9Dataset.
inline std::shared_ptr<EnWik9Dataset> MS_API EnWik9(const std::string &dataset_dir, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<EnWik9Dataset>(StringToChar(dataset_dir), num_samples, shuffle, num_shards, shard_id, cache);
}
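A minimal usage sketch of the new C++ API, following the pattern of the unit tests added below (the dataset path is a placeholder and Status checks are omitted):

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>
#include "minddata/dataset/include/dataset/datasets.h"

int main() {
  using namespace mindspore::dataset;
  // Read every row of <dataset_dir>/enwik9 in order.
  std::shared_ptr<Dataset> ds = EnWik9("/path/to/EnWik9", 0, ShuffleMode::kFalse);
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  std::unordered_map<std::string, mindspore::MSTensor> row;
  iter->GetNextRow(&row);
  uint64_t count = 0;
  while (!row.empty()) {
    count++;  // Each row has the single output column "text".
    iter->GetNextRow(&row);
  }
  std::cout << "rows: " << count << std::endl;
  iter->Stop();
  return 0;
}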
/// \class FakeImageDataset
/// \brief A source dataset for generating fake images.
class MS_API FakeImageDataset : public Dataset {

View File

@@ -76,7 +76,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \
check_conll2000_dataset, check_amazon_review_dataset, check_semeion_dataset, check_caltech101_dataset, \
check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset, check_wider_face_dataset
check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset, check_wider_face_dataset, \
check_en_wik9_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -10810,6 +10811,82 @@ class STL10Dataset(MappableDataset):
return cde.STL10Node(self.dataset_dir, self.usage, self.sampler)
class EnWik9Dataset(SourceDataset):
"""
A source dataset that reads and parses the EnWik9 dataset.
The generated dataset has one column :py:obj:`[text]` with type string.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, will include all samples).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=True).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Examples:
>>> en_wik9_dataset_dir = "/path/to/en_wik9_dataset"
>>> dataset = ds.EnWik9Dataset(dataset_dir=en_wik9_dataset_dir, num_samples=2,
...                            shuffle=True)
About EnWik9 dataset:
The data of EnWik9 is UTF-8 encoded XML consisting primarily of English text. It contains 243,426 article titles,
of which 85,560 are #REDIRECT to fix broken links, and the rest are regular articles.
The data is UTF-8 clean. All characters are in the range U+0000 to U+10FFFF with valid encodings of 1 to
4 bytes. The byte values 0xC0, 0xC1, and 0xF5-0xFF never occur. Also, in the Wikipedia dumps,
there are no control characters in the range 0x00-0x1F except for 0x09 (tab) and 0x0A (linefeed).
Linebreaks occur only on paragraph boundaries, so they always have a semantic purpose.
You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
.. code-block::

    .
    └── EnWik9
        └── enwik9
Citation:
.. code-block::
@NetworkResource{Hutter_prize,
author = {English Wikipedia},
url = "https://cs.fit.edu/~mmahoney/compression/textdata.html",
month = {March},
year = {2006}
}
"""
@check_en_wik9_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=True,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
def parse(self, children=None):
return cde.EnWik9Node(self.dataset_dir, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class YahooAnswersDataset(SourceDataset):
"""
A source dataset that reads and parses the YahooAnswers dataset.

View File

@@ -2499,3 +2499,26 @@ def check_wiki_text_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_en_wik9_dataset(method):
"""Wrapper method to check the parameters of EnWik9 dataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

View File

@@ -28,6 +28,7 @@ SET(DE_UT_SRCS
c_api_dataset_dbpedia_test.cc
c_api_dataset_div2k_test.cc
c_api_dataset_emnist_test.cc
c_api_dataset_en_wik9_test.cc
c_api_dataset_fake_image_test.cc
c_api_dataset_fashion_mnist_test.cc
c_api_dataset_flickr_test.cc

View File

@@ -0,0 +1,427 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::ShuffleMode;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: get 2 samples as expected
TEST_F(MindDataTestPipeline, TestEnWik9DatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetBasic.";
// Test EnWik9 Dataset with a single text file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create an EnWik9 Dataset with a single enwik9 file.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Use 2 samples.
// Use defaults for other input parameters.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds = EnWik9(tf_file, 2);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {" <title>MindSpore</title>", " <page>"};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result.
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 2 samples.
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: get 10 samples after repeat and concat
TEST_F(MindDataTestPipeline, TestEnWik9DatasetBasicAndRepeat) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetBasicAndRepeat.";
// Test EnWik9 Dataset with a single enwik9 file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create two EnWik9 Datasets, each with a single enwik9 file.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Use 2 samples.
// Use defaults for other input parameters.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds1 = EnWik9(tf_file, 2);
std::shared_ptr<Dataset> ds2 = EnWik9(tf_file, 2);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operations, one on each dataset.
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the two datasets.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {" <page>", " <title>MindSpore</title>"};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 10 samples.
EXPECT_EQ(i, 10);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: the getters return the expected values
TEST_F(MindDataTestPipeline, TestEnWik9Getters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9Getters.";
// Test EnWik9 Dataset with a single text file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create an EnWik9 Dataset with a single enwik9 file.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Use 2 samples.
// Use defaults for other input parameters.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds = EnWik9(tf_file, 2);
EXPECT_NE(ds, nullptr);
std::vector<std::string> column_names = {"text"};
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
ds = EnWik9(tf_file, 0);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 13);
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: fail to create an iterator because the dataset path does not exist
TEST_F(MindDataTestPipeline, TestEnWik9DatasetFailNoExistentPath) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetFailNoExistentPath.";
// Create an EnWik9 Dataset with a non-existent dataset_dir input.
std::shared_ptr<Dataset> ds = EnWik9("/NotExist", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: specified dataset_files does not exist
EXPECT_EQ(iter, nullptr);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: the rows are returned in file order
TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleFalse1A) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleFalse1A.";
// Test EnWik9 Dataset with one enwik9 file and no shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(654);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create an EnWik9 Dataset with one enwik9 file, /testEnWik9Dataset/enwik9.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Use the default of all samples.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>"};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result.
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 13 samples.
EXPECT_EQ(i, 13);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: each shard gets the expected rows
TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleFalse4Shard) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleFalse4Shard.";
// Test EnWik9 Dataset with one enwik9 file and no shuffle, num_parallel_workers=4, shard coverage.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(654);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create an EnWik9 Dataset with one enwik9 file.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Set shuffle to ShuffleMode::kFalse, num_shards=2, shard_id=0.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kFalse, 2, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>"};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result.
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 7 samples for this shard.
EXPECT_EQ(i, 7);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: EnWik9Dataset
/// Description: test EnWik9Dataset in pipeline mode
/// Expectation: the rows are globally shuffled with the given seed
TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleGlobal1A) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleGlobal1A.";
// Test EnWik9 Dataset with one enwik9 file, global shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create an EnWik9 Dataset with one enwik9 file.
// Note: /testEnWik9Dataset/enwik9 has 13 rows.
// Set shuffle to global shuffle.
std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset";
std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {" </contributor>",
" <page>",
" <contributor>",
" <username>MS</username>",
" <title>MindSpore</title>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <text xml:space=\"preserve\">666</text>",
" <revision>",
" <id>567</id>",
" </revision>",
" </page>",
" <id>234</id>",
" <id>1</id>"};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result.
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 13 samples.
EXPECT_EQ(i, 13);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@@ -0,0 +1,13 @@
<page>
<title>MindSpore</title>
<id>1</id>
<revision>
<id>234</id>
<timestamp>2020-01-01T00:00:00Z</timestamp>
<contributor>
<username>MS</username>
<id>567</id>
</contributor>
<text xml:space="preserve">666</text>
</revision>
</page>

View File

@@ -0,0 +1,296 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed
DATA_FILE = "../data/dataset/testEnWik9Dataset"
def test_enwik9_total_rows_dataset_num_samples_none():
"""
Feature: EnWik9Dataset
Description: test EnWik9Dataset with the default num_samples (None, meaning all samples)
Expectation: the number of samples is 13
"""
# Do not provide a num_samples argument, so it would be None by default.
data = ds.EnWik9Dataset(DATA_FILE)
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 13
def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two():
"""
Feature: EnWik9Dataset
Description: test the function while param shuffle = False
Expectation: the samples are returned in file order
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(2)
original_seed = config_get_set_seed(987)
data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
count = 0
line = [" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 13
# Restore configuration.
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one():
"""
Feature: EnWik9Dataset
Description: test the function while param shuffle = False
Expectation: the samples are returned in file order
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
count = 0
line = [" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 13
# Restore configuration.
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two():
"""
Feature: EnWik9Dataset
Description: test the function while param shuffle = True
Expectation: the samples are shuffled
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(2)
original_seed = config_get_set_seed(135)
data = ds.EnWik9Dataset(DATA_FILE, shuffle=True)
count = 0
line = [" <username>MS</username>",
" <title>MindSpore</title>",
" <id>234</id>",
" </revision>",
" </contributor>",
" <revision>",
" <id>567</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <id>1</id>",
" </page>",
" <page>",
" <text xml:space=\"preserve\">666</text>",
" <contributor>"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 13
# Restore configuration.
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one():
"""
Feature: EnWik9Dataset
Description: test the function while param shuffle = True
Expectation: the samples are shuffled
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(135)
data = ds.EnWik9Dataset(DATA_FILE, shuffle=True)
count = 0
line = [" <username>MS</username>",
" <title>MindSpore</title>",
" <id>234</id>",
" </revision>",
" </contributor>",
" <revision>",
" <id>567</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <id>1</id>",
" </page>",
" <page>",
" <text xml:space=\"preserve\">666</text>",
" <contributor>"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 13
# Restore configuration.
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_enwik9_dataset_num_samples():
"""
Feature: EnWik9Dataset
Description: test the num_samples parameter with num_samples=2
Expectation: the number of samples = 2
"""
data = ds.EnWik9Dataset(DATA_FILE, num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_enwik9_dataset_distribution():
"""
Feature: EnWik9Dataset
Description: test sharding of the dataset with num_shards=2 and shard_id=1
Expectation: count = 7
"""
data = ds.EnWik9Dataset(DATA_FILE, num_shards=2, shard_id=1)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 7
def test_enwik9_total_rows_dataset_repeat():
"""
Feature: EnWik9Dataset
Description: test the function while the dataset is repeated twice
Expectation: count = 26
"""
data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
data = data.repeat(2)
count = 0
line = [" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>",
" <page>",
" <title>MindSpore</title>",
" <id>1</id>",
" <revision>",
" <id>234</id>",
" <timestamp>2020-01-01T00:00:00Z</timestamp>",
" <contributor>",
" <username>MS</username>",
" <id>567</id>",
" </contributor>",
" <text xml:space=\"preserve\">666</text>",
" </revision>",
" </page>"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 26
def test_enwik9_total_rows_dataset_get_datasetsize():
"""
Feature: EnWik9Dataset
Description: test get_dataset_size()
Expectation: size = 13
"""
data = ds.EnWik9Dataset(DATA_FILE)
size = data.get_dataset_size()
assert size == 13
def test_enwik9_total_rows_dataset_to_device():
"""
Feature: EnWik9Dataset
Description: test to_device() and send()
Expectation: the pipeline sends data to device without error
"""
data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
data = data.to_device()
data.send()
def test_enwik9_dataset_exceptions():
"""
Feature: EnWik9Dataset
Description: test the errors which appear possibly
Expectation: the errors are expected correctly
"""
with pytest.raises(ValueError) as error_info:
_ = ds.EnWik9Dataset("does/not/exist/")
assert "does not exist or is not a directory or permission denied" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.EnWik9Dataset("")
assert "The folder does not exist or is not a directory or permission denied" in str(error_info.value)
def exception_func(item):
raise Exception("Error occur!")
with pytest.raises(RuntimeError) as error_info:
data = ds.EnWik9Dataset(DATA_FILE)
data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
if __name__ == "__main__":
test_enwik9_total_rows_dataset_num_samples_none()
test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two()
test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one()
test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two()
test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one()
test_enwik9_dataset_num_samples()
test_enwik9_dataset_distribution()
test_enwik9_total_rows_dataset_repeat()
test_enwik9_total_rows_dataset_get_datasetsize()
test_enwik9_total_rows_dataset_to_device()
test_enwik9_dataset_exceptions()