!19986 [assistant][ops] Add new dataset loading operator AGNEWS

Merge pull request !19986 from 杨旭华/AGNEWS
This commit is contained in:
i-robot 2021-11-11 06:22:25 +00:00 committed by Gitee
commit b910870ecc
18 changed files with 1392 additions and 3 deletions

View File

@@ -83,6 +83,7 @@
#include "minddata/dataset/util/services.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
@@ -851,6 +852,14 @@ std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint6
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
return cache;
}
AGNewsDataset::AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<AGNewsNode>(CharToString(dataset_dir), num_samples, shuffle, CharToString(usage),
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
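A side note for reviewers: the std::vector<char> parameters in this constructor follow the public API's dual-ABI convention of passing strings as char vectors across the library boundary. Assuming the usual StringToChar/CharToString helpers, the conversion is a byte-for-byte copy, roughly:

// Illustrative sketch only (not part of this patch); helper names assumed from the dual-ABI convention.
std::vector<char> chars = StringToChar("train");  // {'t', 'r', 'a', 'i', 'n'}
std::string usage = CharToString(chars);          // back to "train"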
AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,

View File

@@ -25,6 +25,7 @@
#include "minddata/dataset/util/path.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
@@ -62,6 +63,18 @@ namespace dataset {
// PYBIND FOR LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
(void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode",
"to create an AGNewsNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) {
auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle),
usage, num_shards, shard_id, nullptr);
THROW_IF_ERROR(ag_news->ValidateParams());
return ag_news;
}));
}));
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
"to create a CelebANode")

View File

@@ -27,6 +27,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
places365_op.cc
photo_tour_op.cc
fashion_mnist_op.cc
ag_news_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/ag_news_op.h"
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore {
namespace dataset {
AGNewsOp::AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, const std::vector<std::string> &ag_news_list)
: CsvOp(ag_news_list, field_delim, column_default, column_name, num_workers, num_samples, worker_connector_size,
op_connector_size, shuffle_files, num_devices, device_id) {}
// A print method typically used for debugging.
void AGNewsOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAGNews files list:\n";
for (size_t i = 0; i < csv_files_list_.size(); ++i) {
out << " " << csv_files_list_[i];
}
out << "\n\n";
}
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/auto_index.h"
namespace mindspore {
namespace dataset {
class JaggedConnector;
class AGNewsOp : public CsvOp {
public:
/// \brief Constructor.
/// \param[in] num_workers Number of workers reading data in parallel.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples).
/// \param[in] worker_connector_size Size of each internal queue.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices Number of devices that the dataset should be divided into. (Default = 1)
/// \param[in] device_id The device ID within num_devices. This argument should be
/// specified only when num_devices is also specified (Default = 0).
/// \param[in] field_delim A char that indicates the delimiter to separate fields (default=',').
/// \param[in] column_default List of default values for CSV fields (default={}). Each item in the list must be
/// of a valid type (float, int, or string). If this is not provided, all columns are treated as string type.
/// \param[in] column_name List of column names of the dataset (default={}). If this is not provided, the
/// column names are inferred from the first row of the CSV file.
/// \param[in] ag_news_list List of files to be read. The list will be sorted lexicographically.
AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
const std::vector<std::string> &ag_news_list);
/// \brief Default destructor.
~AGNewsOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a
/// summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "AGNewsOp"; }
/// \brief DatasetName getter.
/// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "AGNews" : "ag news"; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_

View File

@@ -183,7 +183,7 @@ class CsvOp : public NonMappableLeafOp {
// \return DatasetName of the current Op
virtual std::string DatasetName(bool upper = false) const { return upper ? "CSV" : "csv"; }
private:
protected:
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.

View File

@@ -74,6 +74,7 @@ constexpr char kTransferNode[] = "Transfer";
constexpr char kZipNode[] = "Zip";
// Names for leaf IR node
constexpr char kAGNewsNode[] = "AGNewsDataset";
constexpr char kAlbumNode[] = "AlbumDataset";
constexpr char kCelebANode[] = "CelebADataset";
constexpr char kCifar100Node[] = "Cifar100Dataset";

View File

@@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_subdirectory(samplers)
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
ag_news_node.cc
album_node.cc
celeba_node.cc
cifar100_node.cc

View File

@@ -0,0 +1,205 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/ag_news_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for AGNewsNode.
AGNewsNode::AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle,
const std::string &usage, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
usage_(usage),
ag_news_files_list_(WalkAllFiles(usage, dataset_dir)) {
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> AGNewsNode::Copy() {
auto node =
std::make_shared<AGNewsNode>(dataset_dir_, num_samples_, shuffle_, usage_, num_shards_, shard_id_, cache_);
return node;
}
void AGNewsNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status AGNewsNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("AGNewsNode", usage_, {"train", "test", "all"}));
if (num_samples_ < 0) {
std::string err_msg = "AGNewsNode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (num_shards_ < 1) {
std::string err_msg = "AGNewsNode: Invalid number of shards: " + std::to_string(num_shards_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsNode", num_shards_, shard_id_));
if (!column_names_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsNode", "column_names", column_names_));
}
return Status::OK();
}
// Function to build AGNewsNode.
Status AGNewsNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order.
std::vector<std::string> sorted_dataset_files = ag_news_files_list_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
// Because AGNews does not expose external column_defaults or column_names parameters,
// they need to be set before AGNewsOp is initialized.
// The AGNews dataset is formatted as three columns of data, so three string columns are added.
std::vector<std::shared_ptr<AGNewsOp::BaseRecord>> column_default;
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
std::vector<std::string> column_name = {"index", "title", "description"};
// AGNews data values are always delimited by a comma.
char field_delim = ',';
std::shared_ptr<AGNewsOp> ag_news_op =
std::make_shared<AGNewsOp>(num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
num_shards_, shard_id_, field_delim, column_default, column_name, sorted_dataset_files);
RETURN_IF_NOT_OK(ag_news_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp.
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset.
RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows));
// Add the shuffle op after this op.
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
ag_news_op->SetTotalRepeats(GetTotalRepeats());
ag_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(ag_news_op);
return Status::OK();
}
// Get the shard id of node.
Status AGNewsNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size.
Status AGNewsNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
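For intuition, a worked instance of the size computation above; the row count is the standard AG News test.csv size quoted in the Python docstring later in this change, and the helper name is hypothetical:

#include <algorithm>
#include <cmath>
#include <cstdint>
// Sketch mirroring AGNewsNode::GetDatasetSize(): split rows across shards, then cap at num_samples.
int64_t SketchDatasetSize(int64_t num_rows, int32_t num_shards, int64_t num_samples) {
  int64_t per_shard = static_cast<int64_t>(std::ceil(num_rows / (1.0 * num_shards)));
  return num_samples > 0 ? std::min(per_shard, num_samples) : per_shard;
}
// e.g. SketchDatasetSize(7600, 4, 0) == 1900, and SketchDatasetSize(7600, 4, 500) == 500.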
Status AGNewsNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and
// should be promoted to its parent class. AGNews (which is internally based on CSV)
// is by itself a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree,
// that cache can inherit this sampler from the leaf, providing sampling support from
// the caching layer. That is why we set up the sampler for a leaf node that does not
// use sampling.
Status AGNewsNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this AGNews node, then
// the cache will be executing a sampler for fetching the data. As such, any
// options in the AGNews node need to be reset to its defaults so that this
// AGNews node will produce the full set of data into the cache.
Status AGNewsNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
std::vector<std::string> AGNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
std::vector<std::string> ag_news_files_list;
Path train_prefix("train.csv");
Path test_prefix("test.csv");
Path dir(dataset_dir);
if (usage == "train") {
Path temp_path = dir / train_prefix;
ag_news_files_list.push_back(temp_path.ToString());
} else if (usage == "test") {
Path temp_path = dir / test_prefix;
ag_news_files_list.push_back(temp_path.ToString());
} else {
Path temp_path = dir / train_prefix;
ag_news_files_list.push_back(temp_path.ToString());
Path temp_path1 = dir / test_prefix;
ag_news_files_list.push_back(temp_path1.ToString());
}
return ag_news_files_list;
}
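To make the usage-to-files mapping concrete, a hedged illustration with a hypothetical directory:

// For dataset_dir = "/data/ag_news":
//   usage = "train" -> {"/data/ag_news/train.csv"}
//   usage = "test"  -> {"/data/ag_news/test.csv"}
//   usage = "all"   -> {"/data/ag_news/train.csv", "/data/ag_news/test.csv"}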
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class AGNewsNode
/// \brief A Dataset derived class to represent the AGNews dataset.
class AGNewsNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, const std::string &usage,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~AGNewsNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kAGNewsNode; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
int64_t NumSamples() const { return num_samples_; }
ShuffleMode Shuffle() const { return shuffle_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief AGNews by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in
/// the tree, that cache can inherit this sampler from the leaf, providing
/// sampling support from the caching layer. That is why we setup the
/// sampler for a leaf node that does not use sampling. Note: This
/// function is common among NonMappableSourceNode and should be promoted
/// to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this ag_news node,
/// then the cache will be executing a sampler for fetching the data.
/// As such, any options in the AGNews node need to be reset to its defaults
/// so that this AGNews node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its
/// parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
/// \brief Generate a list of read file names according to usage.
/// \param[in] usage Part of dataset of AGNews.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \return std::vector<std::string> A list of read file names.
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
private:
std::string dataset_dir_;
std::string usage_;
char field_delim_;
std::vector<std::shared_ptr<CsvBase>> column_defaults_;
std::vector<std::string> column_names_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
std::vector<std::string> ag_news_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_

View File

@@ -993,6 +993,56 @@ inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
return SchemaCharIF(StringToChar(schema_file));
}
/// \class AGNewsDataset
/// \brief A source dataset that reads and parses AG News datasets.
class AGNewsDataset : public Dataset {
public:
/// \brief Constructor of AGNewsDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The part of the dataset to be read, can be "train", "test" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of AGNewsDataset.
~AGNewsDataset() = default;
};
/// \brief Function to create an AGNewsDataset.
/// \note The generated dataset has three columns ['index', 'title', 'description'].
/// The index range is [1, 4].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage One of "all", "train" or "test" (default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch
/// (Default=ShuffleMode::kGlobal).
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the AGNewsDataset.
inline std::shared_ptr<AGNewsDataset> AGNews(const std::string &dataset_dir, const std::string &usage = "all",
int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<AGNewsDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
num_shards, shard_id, cache);
}
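A minimal usage sketch of this factory, mirroring the C++ pipeline tests added later in this change (the directory path is a placeholder):

// Build an AGNews pipeline over the train split with file-level shuffling, then pull one row.
std::shared_ptr<Dataset> ds = AGNews("/path/to/ag_news", "train", 0, ShuffleMode::kFiles);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);  // row holds the "index", "title" and "description" columns
iter->Stop();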
/// \class AlbumDataset
/// \brief A source dataset for reading and parsing Album dataset.
class AlbumDataset : public Dataset {

View File

@@ -67,7 +67,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset
check_photo_tour_dataset, check_ag_news_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -5075,6 +5075,93 @@ class ManifestDataset(MappableDataset):
return self.class_indexing
class AGNewsDataset(SourceDataset):
"""
A source dataset that reads and parses AG News datasets.
The generated dataset has three columns: :py:obj:`[index, title, description]`.
The tensor of column :py:obj:`index` is of the string type.
The tensor of column :py:obj:`title` is of the string type.
The tensor of column :py:obj:`description` is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Acceptable usages include `train`, `test` and `all` (default=None, which is equivalent to `all`).
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed.
If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, 'num_samples' reflects the maximum sample number per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Examples:
>>> ag_news_dataset_dir = "/path/to/ag_news_dataset_file"
>>> dataset = ds.AGNewsDataset(dataset_dir=ag_news_dataset_dir, usage='all')
About AGNews dataset:
AG is a collection of over 1 million news articles. The news articles were collected
by ComeToMyHead from over 2,000 news sources in over 1 year of activity. ComeToMyHead
is an academic news search engine that has been in operation since July 2004.
The dataset is provided by academics for research purposes such as data mining
(clustering, classification, etc.), information retrieval (ranking, searching, etc.),
xml, data compression, data streaming, and any other non-commercial activities.
AG's news topic classification dataset was constructed by selecting the four largest
classes from the original corpus. Each class contains 30,000 training samples and
1,900 test samples. The total number of training samples in train.csv is 120,000
and the number of test samples in test.csv is 7,600.
You can unzip the dataset files into the following structure and read them with MindSpore's API:
.. code-block::
.
└── ag_news_dataset_dir
     ├── classes.txt
     ├── train.csv
     ├── test.csv
     └── readme.txt
Citation:
.. code-block::
@misc{zhang2015characterlevel,
title={Character-level Convolutional Networks for Text Classification},
author={Xiang Zhang and Junbo Zhao and Yann LeCun},
year={2015},
eprint={1509.01626},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""
@check_ag_news_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.AGNewsNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class Cifar10Dataset(MappableDataset):
"""
A source dataset for reading and parsing Cifar10 dataset.

View File

@@ -535,7 +535,7 @@ def check_generatordataset(method):
raise ValueError("Neither columns_names nor schema are provided.")
if schema is not None:
if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
if not isinstance(schema, (datasets.Schema, str)):
raise ValueError("schema should be a path to schema file or a schema object.")
# check optional argument
@@ -1728,3 +1728,33 @@ def check_fake_image_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_ag_news_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(AGNewsDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
# check usage
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

View File

@@ -14,6 +14,7 @@ SET(DE_UT_SRCS
c_api_audio_a_to_q_test.cc
c_api_audio_r_to_z_test.cc
c_api_cache_test.cc
c_api_dataset_ag_news_test.cc
c_api_dataset_album_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_cityscapes_test.cc

View File

@@ -0,0 +1,560 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
using namespace mindspore::dataset;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetBasic.";
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::vector<std::string> column_names = {"index", "title", "description"};
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 2 samples.
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: test AGNewsDataset getters: output types, shapes, column names and dataset size.
/// Expectation: the getters return the expected values.
TEST_F(MindDataTestPipeline, TestAGNewsGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsGetters.";
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 3);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(types[1].ToString(), "string");
EXPECT_EQ(types[2].ToString(), "string");
EXPECT_EQ(shapes.size(), 3);
EXPECT_EQ(shapes[0].ToString(), "<>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(shapes[2].ToString(), "<>");
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: test AGNewsDataset with invalid inputs: empty dataset_dir, a non-existent path,
/// num_samples < 0, num_shards < 1 and shard_id >= num_shards.
/// Expectation: iterator creation fails for each invalid input.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetFail.";
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::string invalid_csv_file = "./NotExistFile";
std::vector<std::string> column_names = {"index", "title", "description"};
std::shared_ptr<Dataset> ds0 = AGNews("", "test", 0);
EXPECT_NE(ds0, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
// Expect failure: invalid AGNews input.
EXPECT_EQ(iter0, nullptr);
// Create an AGNews Dataset with a non-existent dataset_dir.
std::shared_ptr<Dataset> ds1 = AGNews(invalid_csv_file);
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid AGNews input.
EXPECT_EQ(iter1, nullptr);
// Test invalid num_samples: -1 (num_samples must be >= 0).
std::shared_ptr<Dataset> ds2 =
AGNews(dataset_dir, "test", -1, ShuffleMode::kFalse);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid AGNews input.
EXPECT_EQ(iter2, nullptr);
// Test invalid num_shards < 1.
std::shared_ptr<Dataset> ds3 =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid AGNews input.
EXPECT_EQ(iter3, nullptr);
// Test invalid shard_id >= num_shards.
std::shared_ptr<Dataset> ds4 =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds4, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
// Expect failure: invalid AGNews input.
EXPECT_EQ(iter4, nullptr);
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data with num_samples set to 2.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetNumSamples) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetNumSamples.";
// Create an AGNewsDataset with a single CSV file.
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "test", 2, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 2 samples.
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data with num_shards = 2 and shard_id = 0.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetDistribution) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetDistribution.";
// Create an AGNewsDataset with a single CSV file.
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 1 sample.
EXPECT_EQ(i, 1);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data from multiple files (usage = "all").
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetMultiFiles) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetMultiFiles.";
// Create an AGNewsDataset reading multiple CSV files.
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "all", 0, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"3", "Demand analysis",
"\"Users simply click on the module they want to view to "
"browse information about that module.\""},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
{"3", "UML Timing Diagram",
"Information is mainly displayed using locally stored data and mapping, "
"which is not timely and does not have the ability to update itself."},
{"3", "In summary",
"This paper implements a map visualization system for Hangzhou city "
"information, using extensive knowledge of visualization techniques."},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 5 samples.
EXPECT_EQ(i, 5);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data and verify that rows are parsed into the expected columns.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetHeader) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetHeader.";
// Create an AGNewsDataset with a single CSV file.
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 2 samples.
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data with ShuffleMode::kFiles.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesA.";
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers =
GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
<< ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "all", 0, ShuffleMode::kFiles);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Demand analysis",
"\"Users simply click on the module they want to view to "
"browse information about that module.\""},
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"3", "UML Timing Diagram",
"Information is mainly displayed using locally stored data and mapping, "
"which is not timely and does not have the ability to update itself."},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
{"3", "In summary",
"This paper implements a map visualization system for Hangzhou city "
"information, using extensive knowledge of visualization techniques."},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 5 samples.
EXPECT_EQ(i, 5);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(
original_num_parallel_workers);
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data with ShuffleMode::kInfile.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesB.";
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers =
GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
<< ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "all", 0, ShuffleMode::kInfile);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data."},
{"3", "Demand analysis",
"\"Users simply click on the module they want to view to "
"browse information about that module.\""},
{"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""},
{"3", "UML Timing Diagram",
"Information is mainly displayed using locally stored data and mapping, "
"which is not timely and does not have the ability to update itself."},
{"3", "In summary",
"This paper implements a map visualization system for Hangzhou city "
"information, using extensive knowledge of visualization techniques."},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length()
<< ", Text: " << ss.substr(0, 50);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 5 samples.
EXPECT_EQ(i, 5);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(
original_num_parallel_workers);
}
/// Feature: Test AGNewsDataset Dataset.
/// Description: read AGNewsDataset data with ShuffleMode::kGlobal.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleGlobal.";
// Test AGNews Dataset with GLOBAL shuffle.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers =
GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
<< ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(135);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
std::shared_ptr<Dataset> ds =
AGNews(dataset_dir, "train", 0, ShuffleMode::kGlobal);
std::vector<std::string> column_names = {"index", "title", "description"};
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("index"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"3", "UML Timing Diagram",
"Information is mainly displayed using locally stored data and mapping, "
"which is not timely and does not have the ability to update itself."},
{"3", "In summary",
"This paper implements a map visualization system for Hangzhou city "
"information, using extensive knowledge of visualization techniques."},
{"3", "Demand analysis",
"\"Users simply click on the module they want to view to "
"browse information about that module.\""},
};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples.
EXPECT_EQ(i, 3);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(
original_num_parallel_workers);
}

View File

@@ -0,0 +1,2 @@
3,Background of the selection,"In this day and age, the internet is growing rapidly, the total number of connected devices is increasing and we are entering the era of big data."
4,Related technologies,"""Leaflet is the leading open source JavaScript library for mobile-friendly interactive maps."""

View File

@@ -0,0 +1,3 @@
3,Demand analysis,"""Users simply click on the module they want to view to browse information about that module."""
3,UML Timing Diagram,"Information is mainly displayed using locally stored data and mapping, which is not timely and does not have the ability to update itself."
3,In summary,"This paper implements a map visualization system for Hangzhou city information, using extensive knowledge of visualization techniques."

View File

@@ -0,0 +1,163 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
FILE_DIR = '../data/dataset/testAGNews'
def test_ag_news_dataset_basic():
"""
Feature: Test AG News Dataset.
Description: read all data with repeat and skip applied.
Expectation: the data is processed successfully.
"""
buffer = []
data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False)
data = data.repeat(2)
data = data.skip(2)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 8
def test_ag_news_dataset_one_file():
"""
Feature: Test AG News Dataset.
Description: read data from a single file.
Expectation: the data is processed successfully.
"""
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
buffer = []
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 2
def test_ag_news_dataset_all_file():
"""
Feature: Test AG News Dataset(usage=all).
Description: read train data and test data.
Expectation: the data is processed successfully.
"""
buffer = []
data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 5
def test_ag_news_dataset_num_samples():
"""
Feature: Test AG News Dataset.
Description: read data with num_samples=4.
Expectation: the data is processed successfully.
"""
data = ds.AGNewsDataset(FILE_DIR, usage='all', num_samples=4, shuffle=False)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 4
def test_ag_news_dataset_distribution():
"""
Feature: Test AG News Dataset.
Description: read data with num_shards=2 and shard_id=0.
Expectation: the data is processed successfully.
"""
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 1
def test_ag_news_dataset_quoted():
"""
Feature: Test get the AG News Dataset.
Description: read AGNewsDataset data and get data.
Expectation: the data is processed successfully.
"""
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
buffer = []
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.extend([d['index'].item().decode("utf8"),
d['title'].item().decode("utf8"),
d['description'].item().decode("utf8")])
assert buffer == ["3", "Background of the selection",
"In this day and age, the internet is growing rapidly, "
"the total number of connected devices is increasing and "
"we are entering the era of big data.",
"4", "Related technologies",
"\"Leaflet is the leading open source JavaScript library "
"for mobile-friendly interactive maps.\""]
def test_ag_news_dataset_size():
"""
Feature: Test Getters.
Description: test get_dataset_size of AG News dataset.
Expectation: the data is processed successfully.
"""
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
assert data.get_dataset_size() == 2
def test_ag_news_dataset_exception():
"""
Feature: Error Test.
Description: test the wrong input.
Expectation: unable to read in data.
"""
def exception_func(item):
raise Exception("Error occur!")
try:
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
data = data.map(operations=exception_func, input_columns=["description"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
if __name__ == "__main__":
test_ag_news_dataset_basic()
test_ag_news_dataset_one_file()
test_ag_news_dataset_all_file()
test_ag_news_dataset_num_samples()
test_ag_news_dataset_distribution()
test_ag_news_dataset_quoted()
test_ag_news_dataset_size()
test_ag_news_dataset_exception()