!22222 [assistant][ops] Add new loader PennTreebankDataset

Merge pull request !22222 from 杨旭华/PennTreebankDataset
i-robot 2021-12-13 11:43:00 +00:00 committed by Gitee
commit 2a2cd5e4ca
20 changed files with 1654 additions and 5 deletions

View File

@@ -107,6 +107,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
@@ -1398,6 +1399,14 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
}
#ifndef ENABLE_ANDROID
PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<PennTreebankNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {

View File

@@ -43,6 +43,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
@@ -342,6 +343,18 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
*m, "PennTreebankNode", "to create a PennTreebankNode")
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) {
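// Convert the Python-side shuffle integer into the C++ ShuffleMode enum before building the IR node.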
auto penn_treebank = std::make_shared<PennTreebankNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(penn_treebank->ValidateParams());
return penn_treebank;
}));
}));
PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
(void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>(
*m, "PhotoTourNode", "to create a PhotoTourNode")

View File

@@ -24,6 +24,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
mappable_leaf_op.cc
mnist_op.cc
nonmappable_leaf_op.cc
penn_treebank_op.cc
photo_tour_op.cc
places365_op.cc
qmnist_op.cc

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
PennTreebankOp::PennTreebankOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, const std::vector<std::string> &file_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), file_list, op_connector_size,
shuffle_files, num_devices, device_id) {}
// A print method typically used for debugging.
void PennTreebankOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nPennTreebank files list:\n";
for (size_t i = 0; i < text_files_list_.size(); ++i) {
out << " " << text_files_list_[i];
}
out << "\nData Schema:\n";
out << *data_schema_ << "\n\n";
}
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/util/queue.h"
namespace mindspore {
namespace dataset {
class JaggedConnector;
class PennTreebankOp : public TextFileOp {
public:
/// \brief Constructor.
/// \param[in] num_workers Number of workers reading text in parallel.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] worker_connector_size Size of each internal queue.
/// \param[in] data_schema Schema of the dataset.
/// \param[in] file_list List of text files to be read. The list
/// will be sorted in a lexicographical order.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices Number of devices that the dataset should be divided into.
/// \param[in] device_id The device ID within num_devices. This argument should be
/// specified only when num_devices is also specified.
PennTreebankOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, std::unique_ptr<DataSchema> data_schema,
const std::vector<std::string> &file_list, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id);
/// \brief Default destructor.
~PennTreebankOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "PennTreebankOp"; }
/// \brief Dataset name getter.
/// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "PennTreebank" : "penn treebank"; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_

View File

@@ -167,8 +167,7 @@ Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK();
}
// Internal helper function to calculate rows
-int64_t CountTotalRows(const std::string &file) {
+int64_t TextFileOp::CountTotalRows(const std::string &file) {
auto realpath = FileUtils::GetRealPath(file.data());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file, " << file << " does not exist.";
@@ -216,9 +215,24 @@ Status TextFileOp::CalculateNumRowsPerShard() {
Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
RETURN_UNEXPECTED_IF_NULL(count);
int32_t num_workers = GlobalContext::config_manager()->num_parallel_workers();
int32_t connector_que_size = GlobalContext::config_manager()->op_connector_size();
int32_t worker_connector_size = GlobalContext::config_manager()->worker_connector_size();
const int32_t shard_id = 0;
const int32_t num_shards = 1;
const int64_t num_samples = 0;
bool shuffle_files = false;
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
// Create and initialize
std::shared_ptr<TextFileOp> op =
std::make_shared<TextFileOp>(num_workers, num_samples, worker_connector_size, std::move(schema), files,
connector_que_size, shuffle_files, num_shards, shard_id);
RETURN_IF_NOT_OK(op->Init());
*count = 0;
for (auto file : files) {
-*count += CountTotalRows(file);
+*count += op->CountTotalRows(file);
}
return Status::OK();
}

View File

@@ -82,7 +82,7 @@ class TextFileOp : public NonMappableLeafOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return text_files_list_; }
-private:
+protected:
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
@@ -111,6 +111,11 @@ class TextFileOp : public NonMappableLeafOp {
// @return - Status
Status ComputeColMap() override;
// Count number of rows in each file.
// @param file - txt file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
std::vector<std::string> text_files_list_;
std::unique_ptr<DataSchema> data_schema_;
};

View File

@@ -98,6 +98,7 @@ constexpr char kLJSpeechNode[] = "LJSpeechDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kPennTreebankNode[] = "PennTreebankDataset";
constexpr char kPhotoTourNode[] = "PhotoTourDataset";
constexpr char kPlaces365Node[] = "Places365Dataset";
constexpr char kQMnistNode[] = "QMnistDataset";

View File

@@ -24,6 +24,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
manifest_node.cc
minddata_node.cc
mnist_node.cc
penn_treebank_node.cc
photo_tour_node.cc
places365_node.cc
qmnist_node.cc

View File

@@ -0,0 +1,198 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for PennTreebankNode.
PennTreebankNode::PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
penn_treebank_files_list_(WalkAllFiles(usage, dataset_dir)) {
// Update the num_shards_ in global context. This number is, for now, only used by auto_num_worker_pass. User
// discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the
// num_shards_ isn't 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return
// num_shards. Once PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> PennTreebankNode::Copy() {
auto node =
std::make_shared<PennTreebankNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
return node;
}
void PennTreebankNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status PennTreebankNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("PennTreebankNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("PennTreebankNode", usage_, {"train", "test", "valid", "all"}));
if (num_samples_ < 0) {
std::string err_msg = "PennTreebankNode: Invalid number of samples: " + std::to_string(num_samples_);
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetShardParams("PennTreebankNode", num_shards_, shard_id_));
return Status::OK();
}
// Function to build PennTreebankNode.
Status PennTreebankNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
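// Files are shuffled when the mode is kFiles or kGlobal; kGlobal additionally shuffles rows via the ShuffleOp injected below.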
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order.
std::vector<std::string> sorted_dataset_files = penn_treebank_files_list_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// Create and initialize PennTreebankNode.
std::shared_ptr<PennTreebankOp> penn_treebank_op =
std::make_shared<PennTreebankOp>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(penn_treebank_op->Init());
// If a global shuffle is used for PennTreebank, it will inject a shuffle op over the PennTreebank.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset PennTreebank's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp.
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset.
RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows));
// Add the shuffle op after this op.
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
penn_treebank_op->SetTotalRepeats(GetTotalRepeats());
penn_treebank_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add PennTreebankNode.
node_ops->push_back(penn_treebank_op);
return Status::OK();
}
// Get the shard id of node.
Status PennTreebankNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size.
Status PennTreebankNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows));
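// Round up so that every shard reports the same size, e.g. 9 rows over 2 shards -> ceil(9 / 2.0) = 5 rows per shard.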
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status PennTreebankNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode children and should be promoted to the
// parent class. PennTreebank by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we set up the sampler for a leaf node that does not use sampling.
Status PennTreebankNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this PennTreebank node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the PennTreebank node need to be reset to its defaults so
// that this PennTreebank node will produce the full set of data into the cache.
Status PennTreebankNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
std::vector<std::string> PennTreebankNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
std::vector<std::string> penn_treebank_files_list;
Path train_prefix("ptb.train.txt");
Path test_prefix("ptb.test.txt");
Path valid_prefix("ptb.valid.txt");
Path dir(dataset_dir);
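// Map usage to the matching ptb.*.txt file; "all" collects the train, test and valid files.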
if (usage == "train") {
Path temp_path = dir / train_prefix;
penn_treebank_files_list.push_back(temp_path.ToString());
} else if (usage == "test") {
Path temp_path = dir / test_prefix;
penn_treebank_files_list.push_back(temp_path.ToString());
} else if (usage == "valid") {
Path temp_path = dir / valid_prefix;
penn_treebank_files_list.push_back(temp_path.ToString());
} else {
Path temp_path = dir / train_prefix;
penn_treebank_files_list.push_back(temp_path.ToString());
Path temp_path1 = dir / test_prefix;
penn_treebank_files_list.push_back(temp_path1.ToString());
Path temp_path2 = dir / valid_prefix;
penn_treebank_files_list.push_back(temp_path2.ToString());
}
return penn_treebank_files_list;
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,124 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class PennTreebankNode
/// \brief A Dataset derived class to represent PennTreebank dataset.
class PennTreebankNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~PennTreebankNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kPennTreebankNode; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id of the node.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
int64_t NumSamples() const { return num_samples_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
ShuffleMode Shuffle() const { return shuffle_; }
const std::string &Usage() const { return usage_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief PennTreebank by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in
/// the tree, that cache can inherit this sampler from the leaf, providing
/// sampling support from the caching layer. That is why we setup the
/// sampler for a leaf node that does not use sampling. Note: This
/// function is common among NonMappableSourceNode and should be promoted
/// to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this PennTreebank node,
/// then the cache will be executing a sampler for fetching the data.
/// As such, any options in the PennTreebank node need to be reset to its defaults
/// so that this PennTreebank node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its
/// parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
/// \brief Generate a list of read file names according to usage.
/// \param[in] usage Part of dataset of PennTreebank.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \return std::vector<std::string> A list of read file names.
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
private:
std::string dataset_dir_;
std::string usage_;
int64_t num_samples_;
int32_t num_shards_;
int32_t shard_id_;
ShuffleMode shuffle_;
std::vector<std::string> penn_treebank_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_

View File

@@ -3175,6 +3175,58 @@ inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class PennTreebankDataset
/// \brief A source dataset for reading and parsing PennTreebank dataset.
class MS_API PennTreebankDataset : public Dataset {
public:
/// \brief Constructor of PennTreebank Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data list txt file to be read, can be "train", "test", "valid" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of PennTreebankDataset.
~PennTreebankDataset() = default;
};
/// \brief Function to create a PennTreebank Dataset.
/// \note The generated dataset has one column ['text'].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage One of "all", "train", "valid" or "test" (default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0, means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the PennTreebankDataset.
inline std::shared_ptr<PennTreebankDataset> MS_API PennTreebank(const std::string &dataset_dir,
const std::string &usage = "all",
int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<PennTreebankDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
num_shards, shard_id, cache);
}
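For orientation, here is a minimal usage sketch of this new API, mirroring the pipeline tests added further below; the dataset path is a placeholder and error handling is omitted:

#include <memory>
#include <unordered_map>
#include "minddata/dataset/include/dataset/datasets.h"

int main() {
  using namespace mindspore::dataset;
  // Read the PennTreebank test split in order (num_samples = 0 means all samples).
  std::shared_ptr<Dataset> ds = PennTreebank("/path/to/penn_treebank", "test", 0, ShuffleMode::kFalse);
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  std::unordered_map<std::string, mindspore::MSTensor> row;
  iter->GetNextRow(&row);
  while (!row.empty()) {
    auto text = row["text"];  // the single "text" column holds one line of the file per row
    iter->GetNextRow(&row);
  }
  iter->Stop();
  return 0;
}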
/// \class PhotoTourDataset
/// \brief A source dataset for reading and parsing PhotoTour dataset.
class MS_API PhotoTourDataset : public Dataset {

View File

@@ -71,7 +71,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \
-check_stl10_dataset, check_yelp_review_dataset
+check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size, get_auto_offload
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -3953,6 +3953,95 @@ class MnistDataset(MappableDataset):
return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
class PennTreebankDataset(SourceDataset):
"""
A source dataset that reads and parses PennTreebank datasets.
The generated dataset has one column :py:obj:`[text]`.
The tensor of column :py:obj:`text` is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Acceptable usages include 'train', 'test', 'valid' and 'all'.
'train' will read from 42,068 train samples of string type,
'test' will read from 3,370 test samples of string type,
'valid' will read from 3,761 validation samples of string type,
'all' will read from all 49,199 samples of string type (default=None, all samples).
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, 'num_samples' reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Examples:
>>> penn_treebank_dataset_dir = "path/to/penn_treebank_dataset_directory"
>>> dataset = ds.PennTreebankDataset(dataset_dir=penn_treebank_dataset_dir, usage='all')
About PennTreebank dataset:
The Penn Treebank (PTB) dataset is widely used in machine learning for NLP (Natural Language Processing)
research. Word-level PTB does not contain capital letters, numbers, or punctuation, and its vocabulary
is capped at 10k unique words, which is relatively small in comparison to most modern datasets and
can result in a larger number of out-of-vocabulary tokens.
Here is the original PennTreebank dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
PennTreebank_dataset_dir
ptb.test.txt
ptb.train.txt
ptb.valid.txt
Citation:
.. code-block::
@techreport{Santorini1990,
added-at = {2014-03-26T23:25:56.000+0100},
author = {Santorini, Beatrice},
biburl = {https://www.bibsonomy.org/bibtex/234cdf6ddadd89376090e7dada2fc18ec/butonic},
file = {:Santorini - Penn Treebank tag definitions.pdf:PDF},
institution = {Department of Computer and Information Science, University of Pennsylvania},
interhash = {818e72efd9e4b5fae3e51e88848100a0},
intrahash = {34cdf6ddadd89376090e7dada2fc18ec},
keywords = {dis pos tagging treebank},
number = {MS-CIS-90-47},
timestamp = {2014-03-26T23:25:56.000+0100},
title = {Part-of-speech tagging guidelines for the {P}enn {T}reebank {P}roject},
url = {ftp://ftp.cis.upenn.edu/pub/treebank/doc/tagguide.ps.gz},
year = 1990
}
"""
@check_penn_treebank_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.PennTreebankNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class PhotoTourDataset(MappableDataset):
"""
A source dataset for reading and parsing the PhotoTour dataset.

View File

@@ -1188,6 +1188,35 @@ def check_textfiledataset(method):
return new_method
def check_penn_treebank_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(PennTreebankDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
# check usage
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "valid", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_split(method):
"""check the input arguments of split."""

View File

@@ -34,6 +34,7 @@ SET(DE_UT_SRCS
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_penn_treebank_test.cc
c_api_dataset_photo_tour_test.cc
c_api_dataset_places365_test.cc
c_api_dataset_qmnist_test.cc

View File

@@ -0,0 +1,588 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::ShuffleMode;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data and verify the output rows.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasic.";
// Test PennTreebank Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" clash twits poetry formulate flip loyalty splash "},
{" you pay less for the supermaket's own brands "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data through a pipeline with Repeat and Concat operations.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasicWithPipeline.";
// Test PennTreebank Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds1 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" clash twits poetry formulate flip loyalty splash "},
{" you pay less for the supermaket's own brands "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 15 samples
EXPECT_EQ(i, 15);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: test PennTreebankDataset getters (dataset size, column names, output types and shapes).
/// Expectation: the getters return the expected values.
TEST_F(MindDataTestPipeline, TestPennTreebankGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankGetters.";
// Test PennTreebank Dataset with single text file and many default inputs
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::vector<std::string> column_names = {"text"};
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 3);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 1);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].ToString(), "<>");
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail1.";
// Create a PennTreebank Dataset
// with invalid num_samples=-1
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", -1, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: PennTreebank number of samples cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail2.";
// Attempt to create a PennTreebank Dataset
// with an invalid (non-existent) dataset_dir input
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank("123", "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: specified dataset_dir does not exist
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail3.";
// Create a PennTreebank Dataset
// with invalid usage input
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "asd", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid usage
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail4.";
// Create a PennTreebank Dataset
// with empty string dataset_dir input
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank("", "test", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: specified dataset_dir is empty
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail5) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail5.";
// Create a PennTreebank Dataset
// with invalid num_shards=0 value
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Number of shards cannot be <=0
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail6) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail6.";
// Create a PennTreebank Dataset
// with invalid shard_id=-1 value
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 1, -1);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: shard_id cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: Testing abnormal inputs.
/// Expectation: Exception thrown to be caught.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail7) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail7.";
// Create a PennTreebank Dataset
// with invalid shard_id=2 and num_shards=2 combination
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Cannot have shard_id >= num_shards
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data with ShuffleMode::kFalse.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFalse) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFalse.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(2);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" does the bank charge a fee for setting up the account "},
{" clash twits poetry formulate flip loyalty splash "},
{" <unk> the wardrobe was very small in our room "},
{" you pay less for the supermaket's own brands "},
{" black white grapes "},
{" just ahead of them there was a huge fissure "},
{" <unk> <unk> the proportion of female workers in this company <unk> <unk> "},
{" everyone in our football team is fuming "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data with ShuffleMode::kFiles.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFilesA.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(654);
GlobalContext::config_manager()->set_num_parallel_workers(1);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" does the bank charge a fee for setting up the account "},
{" <unk> the wardrobe was very small in our room "},
{" black white grapes "},
{" no it was black friday "},
{" clash twits poetry formulate flip loyalty splash "},
{" you pay less for the supermaket's own brands "},
{" just ahead of them there was a huge fissure "},
{" <unk> <unk> the proportion of female workers in this company <unk> <unk> "},
{" everyone in our football team is fuming "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data with ShuffleMode::kInfile.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFilesB.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kInfile);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" no it was black friday "},
{" does the bank charge a fee for setting up the account "},
{" just ahead of them there was a huge fissure "},
{" clash twits poetry formulate flip loyalty splash "},
{" <unk> the wardrobe was very small in our room "},
{" <unk> <unk> the proportion of female workers in this company <unk> <unk> "},
{" you pay less for the supermaket's own brands "},
{" black white grapes "},
{" everyone in our football team is fuming "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test PennTreebank Dataset.
/// Description: read PennTreebank data with ShuffleMode::kGlobal.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleGlobal.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create a PennTreebank Dataset with all three text files
// Note: each of ptb.train.txt, ptb.test.txt and ptb.valid.txt has 3 rows
// Set shuffle to global shuffle
std::string dataset_dir = datasets_root_path_ + "/testPennTreebank";
std::shared_ptr<Dataset> ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected_result = {
{" everyone in our football team is fuming "},
{" does the bank charge a fee for setting up the account "},
{" clash twits poetry formulate flip loyalty splash "},
{" no it was black friday "},
{" just ahead of them there was a huge fissure "},
{" <unk> <unk> the proportion of female workers in this company <unk> <unk> "},
{" you pay less for the supermaket's own brands "},
{" <unk> the wardrobe was very small in our room "},
{" black white grapes "},
};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 9 samples
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@ -0,0 +1,3 @@
no it was black friday
clash twits poetry formulate flip loyalty splash
you pay less for the supermaket's own brands

View File

@ -0,0 +1,3 @@
does the bank charge a fee for setting up the account
<unk> the wardrobe was very small in our room
black white grapes

View File

@ -0,0 +1,3 @@
just ahead of them there was a huge fissure
<unk> <unk> the proportion of female workers in this company <unk> <unk>
everyone in our football team is fuming

View File

@ -0,0 +1,385 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed
FILE_DIR = '../data/dataset/testPennTreebank'
def test_penn_treebank_dataset_one_file():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from the test file (usage='test').
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='test')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 3
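# A hedged usage sketch (not run by this suite): `usage` selects which split
# file(s) to read, and 'all' reads every split. The per-split file naming is
# the loader's own convention and is assumed, not asserted, here.
def sketch_penn_treebank_usage():
    for usage in ('train', 'valid', 'test', 'all'):
        data = ds.PennTreebankDataset(FILE_DIR, usage=usage)
        logger.info("usage={}: {} rows".format(usage, data.get_dataset_size()))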
def test_penn_treebank_dataset_train():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from the train file (usage='train').
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='train')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 3
def test_penn_treebank_dataset_valid():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from the valid file (usage='valid').
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='valid')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 3
def test_penn_treebank_dataset_all_file():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files (usage='all').
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='all')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 9
def test_penn_treebank_dataset_num_samples_none():
"""
Feature: Test PennTreebank Dataset.
    Description: read data without specifying num_samples (defaults to None).
Expectation: the data is processed successfully.
"""
# Do not provide a num_samples argument, so it would be None by default
data = ds.PennTreebankDataset(FILE_DIR, usage='all')
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 9
def test_penn_treebank_dataset_shuffle_false4():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=False and 4 parallel workers.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(987)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False)
count = 0
line = [" no it was black friday ",
" does the bank charge a fee for setting up the account ",
" just ahead of them there was a huge fissure ",
" clash twits poetry formulate flip loyalty splash ",
" <unk> the wardrobe was very small in our room ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" you pay less for the supermaket's own brands ",
" black white grapes ",
" everyone in our football team is fuming "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
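# Note on the two shuffle=False cases: row order within each file is always
# preserved, but with 4 parallel workers the three files are read
# concurrently, so rows interleave across files round-robin; with a single
# worker (next test) the files are read one after another.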
def test_penn_treebank_dataset_shuffle_false1():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=False and 1 parallel worker.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False)
count = 0
line = [" no it was black friday ",
" clash twits poetry formulate flip loyalty splash ",
" you pay less for the supermaket's own brands ",
" does the bank charge a fee for setting up the account ",
" <unk> the wardrobe was very small in our room ",
" black white grapes ",
" just ahead of them there was a huge fissure ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" everyone in our football team is fuming "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_penn_treebank_dataset_shuffle_files4():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=Shuffle.FILES and 4 parallel workers.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(135)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
count = 0
line = [" just ahead of them there was a huge fissure ",
" does the bank charge a fee for setting up the account ",
" no it was black friday ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" <unk> the wardrobe was very small in our room ",
" clash twits poetry formulate flip loyalty splash ",
" everyone in our football team is fuming ",
" black white grapes ",
" you pay less for the supermaket's own brands "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_penn_treebank_dataset_shuffle_files1():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=Shuffle.FILES and 1 parallel worker.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(135)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
count = 0
line = [" just ahead of them there was a huge fissure ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" everyone in our football team is fuming ",
" does the bank charge a fee for setting up the account ",
" <unk> the wardrobe was very small in our room ",
" black white grapes ",
" no it was black friday ",
" clash twits poetry formulate flip loyalty splash ",
" you pay less for the supermaket's own brands "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
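# Note on Shuffle.FILES vs Shuffle.GLOBAL: FILES only shuffles the order in
# which the files are read (with one worker the expected rows above appear in
# file-sized groups of three, each keeping its within-file order), while
# GLOBAL, tested next, shuffles individual rows across all files.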
def test_penn_treebank_dataset_shuffle_global4():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=Shuffle.GLOBAL and 4 parallel workers.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(246)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
count = 0
line = [" everyone in our football team is fuming ",
" does the bank charge a fee for setting up the account ",
" clash twits poetry formulate flip loyalty splash ",
" no it was black friday ",
" just ahead of them there was a huge fissure ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" you pay less for the supermaket's own brands ",
" <unk> the wardrobe was very small in our room ",
" black white grapes "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_penn_treebank_dataset_shuffle_global1():
"""
Feature: Test PennTreebank Dataset.
    Description: read data from all files with shuffle=Shuffle.GLOBAL and 1 parallel worker.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(246)
data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
count = 0
line = [" everyone in our football team is fuming ",
" does the bank charge a fee for setting up the account ",
" clash twits poetry formulate flip loyalty splash ",
" <unk> the wardrobe was very small in our room ",
" black white grapes ",
" you pay less for the supermaket's own brands ",
" <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
" no it was black friday ",
" just ahead of them there was a huge fissure "]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_penn_treebank_dataset_num_samples():
"""
Feature: Test PennTreebank Dataset.
Description: Test num_samples.
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_penn_treebank_dataset_distribution():
"""
Feature: Test PennTreebank Dataset.
    Description: read data with num_shards=2 and shard_id=1.
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 5
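# A hedged sketch (per-shard counts other than the one asserted above are an
# assumption): compare how the 9 rows are distributed over two shards.
def sketch_penn_treebank_shards():
    for shard_id in (0, 1):
        data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_shards=2, shard_id=shard_id)
        count = sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
        logger.info("shard {}/2: {} rows".format(shard_id, count))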
def test_penn_treebank_dataset_repeat():
"""
Feature: Test PennTreebank Dataset.
Description: Test repeat.
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='test', shuffle=False)
data = data.repeat(3)
count = 0
line = [" no it was black friday ",
" clash twits poetry formulate flip loyalty splash ",
" you pay less for the supermaket's own brands ",
" no it was black friday ",
" clash twits poetry formulate flip loyalty splash ",
" you pay less for the supermaket's own brands ",
" no it was black friday ",
" clash twits poetry formulate flip loyalty splash ",
" you pay less for the supermaket's own brands ",]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 9
def test_penn_treebank_dataset_get_datasetsize():
"""
Feature: Test PennTreebank Dataset.
    Description: Test get_dataset_size.
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='test')
size = data.get_dataset_size()
assert size == 3
def test_penn_treebank_dataset_to_device():
"""
Feature: Test PennTreebank Dataset.
Description: Test to_device.
Expectation: the data is processed successfully.
"""
data = ds.PennTreebankDataset(FILE_DIR, usage='test')
data = data.to_device()
data.send()
def test_penn_treebank_dataset_exceptions():
"""
Feature: Test PennTreebank Dataset.
Description: Test exceptions.
    Expectation: invalid parameters and a failing map raise the expected errors.
"""
with pytest.raises(ValueError) as error_info:
_ = ds.PennTreebankDataset(FILE_DIR, usage='test', num_samples=-1)
assert "num_samples exceeds the boundary" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.PennTreebankDataset("does/not/exist/no.txt")
assert str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.PennTreebankDataset("")
assert str(error_info.value)
def exception_func(item):
raise Exception("Error occur!")
with pytest.raises(RuntimeError) as error_info:
data = ds.PennTreebankDataset(FILE_DIR)
data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
if __name__ == "__main__":
test_penn_treebank_dataset_one_file()
test_penn_treebank_dataset_train()
test_penn_treebank_dataset_valid()
test_penn_treebank_dataset_all_file()
test_penn_treebank_dataset_num_samples_none()
test_penn_treebank_dataset_shuffle_false4()
test_penn_treebank_dataset_shuffle_false1()
test_penn_treebank_dataset_shuffle_files4()
test_penn_treebank_dataset_shuffle_files1()
test_penn_treebank_dataset_shuffle_global4()
test_penn_treebank_dataset_shuffle_global1()
test_penn_treebank_dataset_num_samples()
test_penn_treebank_dataset_distribution()
test_penn_treebank_dataset_repeat()
test_penn_treebank_dataset_get_datasetsize()
test_penn_treebank_dataset_to_device()
test_penn_treebank_dataset_exceptions()