[feat][assistant][I40GXT] add new loader DBpedia

This commit is contained in:
Carry955 2021-11-11 19:02:22 -08:00
parent a304352179
commit 0d26c38693
17 changed files with 1324 additions and 2 deletions

View File

@ -96,6 +96,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
@ -1029,6 +1030,14 @@ CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
DBpediaDataset::DBpediaDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<DBpediaNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {

View File

@ -33,6 +33,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
@ -158,6 +159,18 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
(void)py::class_<DBpediaNode, DatasetNode, std::shared_ptr<DBpediaNode>>(*m, "DBpediaNode",
"to create a DBpediaNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) {
auto dbpedia = std::make_shared<DBpediaNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(dbpedia->ValidateParams());
return dbpedia;
}));
}));
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
"to create a DIV2KNode")

View File

@ -28,6 +28,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
photo_tour_op.cc
fashion_mnist_op.cc
ag_news_op.cc
dbpedia_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/dbpedia_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <stdexcept>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
DBpediaOp::DBpediaOp(const std::vector<std::string> &dataset_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t num_samples,
int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
int32_t device_id)
: CsvOp(dataset_files_list, field_delim, column_default, column_name, num_workers, num_samples,
worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id) {}
void DBpediaOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nDBpedia files list:\n";
for (int i = 0; i < csv_files_list_.size(); ++i) {
out << " " << csv_files_list_[i];
}
out << "\n\n";
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DBPEDIA_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DBPEDIA_OP_H_
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
namespace mindspore {
namespace dataset {
class DBpediaOp : public CsvOp {
public:
/// \brief Constructor.
/// \param[in] dataset_files_list - List of file paths for the dataset files.
/// \param[in] field_delim - A char that indicates the delimiter to separate fields.
/// \param[in] column_default - List of default values for the CSV fields (default={}). Each item in the list is
/// a valid type (float, int, or string).
/// \param[in] column_name - List of column names of the dataset file.
/// \param[in] num_workers - Num of workers reading files in parallel.
/// \param[in] num_samples - The number of samples to be included in the dataset.
/// \param[in] worker_connector_size - Size of each internal queue.
/// \param[in] op_connector_size - Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files - Whether or not to shuffle the files before reading data.
/// \param[in] num_devices - Number of devices that the dataset should be divided into.
/// \param[in] device_id - The device ID within num_devices.
DBpediaOp(const std::vector<std::string> &dataset_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
/// \brief Destructor.
~DBpediaOp() = default;
/// A print method typically used for debugging
/// @param out - The output stream to write output to
/// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief DatasetName name getter.
/// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name.
/// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "DBpedia" : "dbpedia"; }
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "DBpediaOp"; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DBPEDIA_OP_H_

View File

@ -83,6 +83,7 @@ constexpr char kCityscapesNode[] = "CityscapesDataset";
constexpr char kCLUENode[] = "CLUEDataset";
constexpr char kCocoNode[] = "CocoDataset";
constexpr char kCSVNode[] = "CSVDataset";
constexpr char kDBpediaNode[] = "DBpediaDataset";
constexpr char kDIV2KNode[] = "DIV2KDataset";
constexpr char kEMnistNode[] = "EMnistDataset";
constexpr char kFakeImageNode[] = "FakeImageDataset";

View File

@ -12,6 +12,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
clue_node.cc
coco_node.cc
csv_node.cc
dbpedia_node.cc
div2k_node.cc
emnist_node.cc
fake_image_node.cc

View File

@ -0,0 +1,209 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
DBpediaNode::DBpediaNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {
// Update the num_shards_ in global context. This number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> DBpediaNode::Copy() {
auto node =
std::make_shared<DBpediaNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
return node;
}
void DBpediaNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status DBpediaNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("DBpediaNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("DBpediaNode", usage_, {"train", "test", "all"}));
if (num_samples_ < 0) {
std::string err_msg = "DBpediaNode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetShardParams("DBpediaNode", num_shards_, shard_id_));
return Status::OK();
}
// Function to build DBpediaNode.
Status DBpediaNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order.
std::vector<std::string> sorted_dataset_files;
RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &sorted_dataset_files));
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
char field_delim = ',';
std::vector<std::string> column_names = {"class", "title", "content"};
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
for (auto c : column_names) {
column_default_list.push_back(std::make_shared<DBpediaOp::Record<std::string>>(DBpediaOp::STRING, ""));
}
std::shared_ptr<DBpediaOp> dbpedia_op = std::make_shared<DBpediaOp>(
sorted_dataset_files, field_delim, column_default_list, column_names, num_workers_, num_samples_,
worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(dbpedia_op->Init());
// If a global shuffle is used for DBpedia, it will inject a shuffle op over the DBpedia.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset DBpedia's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp.
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset.
RETURN_IF_NOT_OK(DBpediaOp::CountAllFileRows(sorted_dataset_files, column_names.empty(), &num_rows));
// Add the shuffle op after this op.
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
dbpedia_op->SetTotalRepeats(GetTotalRepeats());
dbpedia_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(dbpedia_op);
return Status::OK();
}
Status DBpediaNode::WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
std::vector<std::string> *dataset_files) {
Path train_file_name("train.csv");
Path test_file_name("test.csv");
Path dir(dataset_dir);
if (usage == "train") {
Path file_path = dir / train_file_name;
dataset_files->push_back(file_path.ToString());
} else if (usage == "test") {
Path file_path = dir / test_file_name;
dataset_files->push_back(file_path.ToString());
} else {
Path file_path_1 = dir / train_file_name;
dataset_files->push_back(file_path_1.ToString());
Path file_path_2 = dir / test_file_name;
dataset_files->push_back(file_path_2.ToString());
}
return Status::OK();
}
// Get the shard id of node.
Status DBpediaNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size.
Status DBpediaNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
std::vector<std::string> column_names = {"class", "title", "content"};
std::vector<std::string> dataset_files;
RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &dataset_files));
RETURN_IF_NOT_OK(DBpediaOp::CountAllFileRows(dataset_files, column_names.empty(), &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(std::ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status DBpediaNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// DBpedia by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
Status DBpediaNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this DBpedia node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the DBpedia node need to be reset to its defaults so
// that this DBpedia node will produce the full set of data into the cache.
Status DBpediaNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
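Note: GetDatasetSize above divides the total row count evenly across shards (rounded up) and then caps the result by num_samples when it is positive. The following standalone C++ sketch is for illustration only and is not part of this commit; PerShardSize is a hypothetical helper name.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the arithmetic in DBpediaNode::GetDatasetSize:
// rows are split evenly across shards (rounded up), then capped by num_samples.
int64_t PerShardSize(int64_t num_rows, int32_t num_shards, int64_t num_samples) {
  int64_t rows_per_shard = static_cast<int64_t>(std::ceil(num_rows / (1.0 * num_shards)));
  return num_samples > 0 ? std::min(rows_per_shard, num_samples) : rows_per_shard;
}

int main() {
  // 630,000 DBpedia rows over 8 shards with no num_samples cap -> 78,750 rows per shard.
  std::cout << PerShardSize(630000, 8, 0) << std::endl;
  return 0;
}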

View File

@ -0,0 +1,120 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DBPEDIA_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DBPEDIA_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/dbpedia_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class DBpediaNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
DBpediaNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~DBpediaNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kDBpediaNode; }
/// \brief Print the description.
/// \param out - The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Generate a list of read file names according to usage.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of DBpedia, can be "train", "test" or "all".
/// \param[in] dataset_files List of filepaths for the dataset files
/// \return Status of the function.
Status WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
std::vector<std::string> *dataset_files);
/// \brief Get the shard id of node.
/// \param[in] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
int64_t NumSamples() const { return num_samples_; }
ShuffleMode Shuffle() const { return shuffle_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief DBpedia by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we setup the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this DBpedia node, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the DBpedia node need to be reset to its defaults so
/// that this DBpedia node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
private:
std::string dataset_dir_;
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DBPEDIA_NODE_H_

View File

@ -1855,6 +1855,54 @@ inline std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_f
cache);
}
/// \class DBpediaDataset
/// \brief A source dataset for reading and parsing the DBpedia dataset.
class DBpediaDataset : public Dataset {
public:
/// \brief Constructor of DBpediaDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of DBpedia, can be "train", "test" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
DBpediaDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of DBpediaDataset.
~DBpediaDataset() = default;
};
/// \brief Function to create a DBpediaDataset.
/// \note The generated dataset has three columns ["class", "title", "content"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of DBpedia, can be "train", "test" or "all" (default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0, means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode::kGlobal).
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the DBpediaDataset
inline std::shared_ptr<DBpediaDataset> DBpedia(const std::string &dataset_dir, const std::string &usage = "all",
int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<DBpediaDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
num_shards, shard_id, cache);
}
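For reference, a minimal usage sketch of the DBpedia() factory declared above, mirroring the pipeline tests later in this commit. It is not part of the header; the dataset directory path is a placeholder and error handling is reduced to the essentials.

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

#include "minddata/dataset/include/dataset/datasets.h"

int main() {
  using mindspore::MSTensor;
  using namespace mindspore::dataset;

  // Read the train split without shuffling; each row exposes the "class", "title" and "content" columns.
  std::shared_ptr<Dataset> ds = DBpedia("/path/to/dbpedia_dataset_directory", "train", 0, ShuffleMode::kFalse);
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  if (iter == nullptr) {
    return 1;  // invalid parameters are reported when the iterator is created
  }

  std::unordered_map<std::string, MSTensor> row;
  int64_t num_rows = 0;
  // GetNextRow returns an empty row once the dataset is exhausted.
  while (iter->GetNextRow(&row).IsOk() && !row.empty()) {
    ++num_rows;
  }
  iter->Stop();
  std::cout << "rows read: " << num_rows << std::endl;
  return 0;
}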
/// \class DIV2KDataset
/// \brief A source dataset for reading and parsing DIV2K dataset.
class DIV2KDataset : public Dataset {

View File

@ -68,7 +68,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_ag_news_dataset
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -7873,6 +7873,102 @@ class CityscapesDataset(MappableDataset):
return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler)
class DBpediaDataset(SourceDataset):
"""
A source dataset that reads and parses the DBpedia dataset.
The generated dataset has three columns :py:obj:`[class, title, content]`.
The tensor of column :py:obj:`class` is of the string type.
The tensor of column :py:obj:`title` is of the string type.
The tensor of column :py:obj:`content` is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be `train`, `test` or `all`.
`train` will read from 560,000 train samples,
`test` will read from 70,000 test samples,
`all` will read from all 630,000 samples (default=None, all samples).
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, will include all text).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum sample number per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Examples:
>>> dbpedia_dataset_dir = "/path/to/dbpedia_dataset_directory"
>>>
>>> # 1) Read 3 samples from DBpedia dataset
>>> dataset = ds.DBpediaDataset(dataset_dir=dbpedia_dataset_dir, num_samples=3)
>>>
>>> # 2) Read train samples from DBpedia dataset
>>> dataset = ds.DBpediaDataset(dataset_dir=dbpedia_dataset_dir, usage="train")
About DBpedia dataset:
The DBpedia dataset consists of 630,000 text samples in 14 classes; there are 560,000 samples in the train.csv
and 70,000 samples in the test.csv.
The 14 different classes represent Company, EducationalInstitution, Artist, Athlete, OfficeHolder,
MeanOfTransportation, Building, NaturalPlace, Village, Animal, Plant, Album, Film, WrittenWork.
Here is the original DBpedia dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
dbpedia_dataset_dir
train.csv
test.csv
classes.txt
readme.txt
.. code-block::
@article{DBpedia,
title = {DBPedia Ontology Classification Dataset},
author = {Jens Lehmann, Robert Isele, Max Jakob, Anja Jentzsch, Dimitris Kontokostas,
Pablo N. Mendes, Sebastian Hellmann, Mohamed Morsey, Patrick van Kleef,
Sören Auer, Christian Bizer},
year = {2015},
howpublished = {http://dbpedia.org}
}
"""
@check_dbpedia_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.DBpediaNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class DIV2KDataset(MappableDataset):
"""
A source dataset for reading and parsing DIV2KDataset dataset.

View File

@ -1752,4 +1752,31 @@ def check_ag_news_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_dbpedia_dataset(method):
"""A wrapper that wraps a parameter checker around the original DBpediaDataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

View File

@ -22,6 +22,7 @@ SET(DE_UT_SRCS
c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_dbpedia_test.cc
c_api_dataset_div2k_test.cc
c_api_dataset_emnist_test.cc
c_api_dataset_fake_image_test.cc

View File

@ -0,0 +1,527 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: DBpedia.
/// Description: read test data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetBasic.";
// Create a DBpedia Dataset
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"5", "My Bedroom", "Look at this room. It's my bedroom."},
{"8", "My English teacher", "She has two big eyes and a small mouth."},
{"6", "My Holiday", "I have a lot of fun every day."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: DBpedia.
/// Description: read train data and test data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetUsageAll) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetUsageAll.";
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "all", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"5", "My Bedroom", "Look at this room. It's my bedroom."},
{"7", "My Last Weekend", "I was busy last week, but I have fun every day."},
{"8", "My English teacher", "She has two big eyes and a small mouth."},
{"5", "My Friend", "She likes singing, dancing and swimming very much."},
{"6", "My Holiday", "I have a lot of fun every day."},
{"8", "I Can Do Housework", "My mother is busy, so I often help my mother with the housework."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 6 samples
EXPECT_EQ(i, 6);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: DBpedia.
/// Description: includes tests for shape, type, size.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetGetters.";
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "test", 0, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"class", "title", "content"};
EXPECT_NE(ds, nullptr);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 3);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(types[1].ToString(), "string");
EXPECT_EQ(types[2].ToString(), "string");
EXPECT_EQ(shapes.size(), 3);
EXPECT_EQ(shapes[0].ToString(), "<>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(shapes[2].ToString(), "<>");
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 3);
EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: DBpedia.
/// Description: read 2 samples from train file.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetNumSamples) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetNumSamples.";
// Create a DBpediaDataset
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "train", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"7", "My Last Weekend", "I was busy last week, but I have fun every day."},
{"5", "My Friend", "She likes singing, dancing and swimming very much."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 2 samples
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: DBpedia.
/// Description: test in a distributed state.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetDistribution) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetDistribution.";
// Create a DBpediaDataset
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "train", 0, ShuffleMode::kFalse, 2, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"7", "My Last Weekend", "I was busy last week, but I have fun every day."},
{"5", "My Friend", "She likes singing, dancing and swimming very much."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 2 samples
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: DBpedia.
/// Description: test with invalid input.
/// Expectation: throw error messages when certain errors occur.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetFail.";
// Create a DBpedia Dataset
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::string invalid_folder_path = "./NotExistPath";
std::vector<std::string> column_names = {"class", "title", "content"};
// Test invalid folder_path
std::shared_ptr<Dataset> ds0 = DBpedia(invalid_folder_path, "all", -1, ShuffleMode::kFalse);
EXPECT_NE(ds0, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
// Expect failure: invalid DBpedia input
EXPECT_EQ(iter0, nullptr);
// Test invalid usage
std::shared_ptr<Dataset> ds1 = DBpedia(folder_path, "invalid_usage", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid DBpedia input
EXPECT_EQ(iter1, nullptr);
// Test invalid num_samples < 0
std::shared_ptr<Dataset> ds2 = DBpedia(folder_path, "all", -1, ShuffleMode::kFalse);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid DBpedia input
EXPECT_EQ(iter2, nullptr);
// Test invalid num_shards < 1
std::shared_ptr<Dataset> ds3 = DBpedia(folder_path, "all", 0, ShuffleMode::kFalse, 0);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid DBpedia input
EXPECT_EQ(iter3, nullptr);
// Test invalid shard_id >= num_shards
std::shared_ptr<Dataset> ds4 = DBpedia(folder_path, "all", 0, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds4, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
// Expect failure: invalid DBpedia input
EXPECT_EQ(iter4, nullptr);
}
/// Feature: DBpedia.
/// Description: read data with pipeline from test file.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetWithPipeline.";
// Create two DBpedia Datasets, each with a single DBpedia file.
std::string dataset_dir = datasets_root_path_ + "/testDBpedia/";
std::shared_ptr<Dataset> ds1 = DBpedia(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = DBpedia(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operations on the datasets.
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operations on the datasets.
std::vector<std::string> column_project = {"class"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the two datasets.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["class"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 15 samples.
EXPECT_EQ(i, 15);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: DBpedia.
/// Description: test with shuffle files.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetShuffleFilesA.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "all", 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"7", "My Last Weekend", "I was busy last week, but I have fun every day."},
{"5", "My Bedroom", "Look at this room. It's my bedroom."},
{"5", "My Friend", "She likes singing, dancing and swimming very much."},
{"8", "My English teacher", "She has two big eyes and a small mouth."},
{"8", "I Can Do Housework", "My mother is busy, so I often help my mother with the housework."},
{"6", "My Holiday", "I have a lot of fun every day."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 6 samples
EXPECT_EQ(i, 6);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: DBpedia.
/// Description: test with in-file shuffle.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetShuffleFilesB.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "test", 0, ShuffleMode::kInfile);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"5", "My Bedroom", "Look at this room. It's my bedroom."},
{"8", "My English teacher", "She has two big eyes and a small mouth."},
{"6", "My Holiday", "I have a lot of fun every day."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: DBpedia.
/// Description: test with global shuffle.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestDBpediaDatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBpediaDatasetShuffleFilesGlobal.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testDBpedia/";
std::vector<std::string> column_names = {"class", "title", "content"};
std::shared_ptr<Dataset> ds = DBpedia(folder_path, "test", 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("class"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"5", "My Bedroom", "Look at this room. It's my bedroom."},
{"6", "My Holiday", "I have a lot of fun every day."},
{"8", "My English teacher", "She has two big eyes and a small mouth."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@ -0,0 +1,3 @@
5,"My Bedroom","Look at this room. It's my bedroom."
8,"My English teacher","She has two big eyes and a small mouth."
6,"My Holiday","I have a lot of fun every day."

View File

@ -0,0 +1,3 @@
7,"My Last Weekend","I was busy last week, but I have fun every day."
5,"My Friend","She likes singing, dancing and swimming very much."
8,"I Can Do Housework","My mother is busy, so I often help my mother with the housework."

View File

@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
DATA_DIR = '../data/dataset/testDBpedia/'
def test_dbpedia_dataset_basic():
"""
Feature: DBpediaDataset.
Description: read data from train file.
Expectation: the data is processed successfully.
"""
buffer = []
data = ds.DBpediaDataset(DATA_DIR, usage="train", shuffle=False)
data = data.repeat(2)
data = data.skip(3)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 3
def test_dbpedia_dataset_quoted():
"""
Feature: DBpediaDataset.
Description: read the data and compare it to expectations.
Expectation: the data is processed successfully.
"""
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False)
buffer = []
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.extend([d['class'].item().decode("utf8"),
d['title'].item().decode("utf8"),
d['content'].item().decode("utf8")])
assert buffer == ["5", "My Bedroom", "Look at this room. It's my bedroom.",
"8", "My English teacher", "She has two big eyes and a small mouth.",
"6", "My Holiday", "I have a lot of fun every day."]
def test_dbpedia_dataset_usage():
"""
Feature: DBpediaDataset.
Description: read all files with usage all.
Expectation: the data is processed successfully.
"""
buffer = []
data = ds.DBpediaDataset(DATA_DIR, usage="all", shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 6
def test_dbpedia_dataset_get_datasetsize():
"""
Feature: DBpediaDataset.
Description: test get_dataset_size function.
Expectation: the data is processed successfully.
"""
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False)
size = data.get_dataset_size()
assert size == 3
def test_dbpedia_dataset_distribution():
"""
Feature: DBpediaDataset.
Description: test in a distributed state.
Expectation: the data is processed successfully.
"""
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=0)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_dbpedia_dataset_num_samples():
"""
Feature: DBpediaDataset.
Description: test num_samples parameter.
Expectation: the data is processed successfully.
"""
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_dbpedia_dataset_exception():
"""
Feature: DBpediaDataset.
Description: test the wrong input.
Expectation: Unable to read data properly.
"""
def exception_func(item):
raise Exception("Error occur!")
try:
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False)
data = data.map(operations=exception_func, input_columns=["class"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.DBpediaDataset(DATA_DIR, usage="test", shuffle=False)
data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
if __name__ == "__main__":
test_dbpedia_dataset_basic()
test_dbpedia_dataset_quoted()
test_dbpedia_dataset_usage()
test_dbpedia_dataset_get_datasetsize()
test_dbpedia_dataset_distribution()
test_dbpedia_dataset_num_samples()
test_dbpedia_dataset_exception()