forked from mindspore-Ecosystem/mindspore
!19986 [assistant][ops] Add new dataset loading operator AGNEWS
Merge pull request !19986 from 杨旭华/AGNEWS
This commit is contained in:
commit
b910870ecc
|
@ -83,6 +83,7 @@
|
|||
#include "minddata/dataset/util/services.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
|
@ -851,6 +852,14 @@ std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint6
|
|||
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
|
||||
return cache;
|
||||
}
|
||||
|
||||
AGNewsDataset::AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto ds = std::make_shared<AGNewsNode>(CharToString(dataset_dir), num_samples, shuffle, CharToString(usage),
|
||||
num_shards, shard_id, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
#endif
|
||||
|
||||
AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "minddata/dataset/util/path.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
|
@ -62,6 +63,18 @@ namespace dataset {
|
|||
// PYBIND FOR LEAF NODES
|
||||
// (In alphabetical order)
|
||||
|
||||
PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode",
|
||||
"to create an AGNewsNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
||||
int32_t num_shards, int32_t shard_id) {
|
||||
auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle),
|
||||
usage, num_shards, shard_id, nullptr);
|
||||
THROW_IF_ERROR(ag_news->ValidateParams());
|
||||
return ag_news;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
|
||||
"to create a CelebANode")
|
||||
|
|
|
@ -27,6 +27,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
places365_op.cc
|
||||
photo_tour_op.cc
|
||||
fashion_mnist_op.cc
|
||||
ag_news_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/ag_news_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
AGNewsOp::AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim,
|
||||
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
|
||||
const std::vector<std::string> &column_name, const std::vector<std::string> &ag_news_list)
|
||||
: CsvOp(ag_news_list, field_delim, column_default, column_name, num_workers, num_samples, worker_connector_size,
|
||||
op_connector_size, shuffle_files, num_devices, device_id) {}
|
||||
|
||||
// A print method typically used for debugging.
|
||||
void AGNewsOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op.
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAGNews files list:\n";
|
||||
for (int i = 0; i < csv_files_list_.size(); ++i) {
|
||||
out << " " << csv_files_list_[i];
|
||||
}
|
||||
out << "\n\n";
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
#include "minddata/dataset/util/auto_index.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class JaggedConnector;
|
||||
|
||||
class AGNewsOp : public CsvOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] num_workers Number of workers reading images in parallel
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||
/// (Default = 0 means all samples).
|
||||
/// \param[in] worker_connector_size Size of each internal queue.
|
||||
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
|
||||
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
|
||||
/// \param[in] num_devices Number of devices that the dataset should be divided into. (Default = 1)
|
||||
/// \param[in] device_id The device ID within num_devices. This argument should be
|
||||
/// specified only when num_devices is also specified (Default = 0).
|
||||
/// \param[in] field_delim A char that indicates the delimiter to separate fields (default=',').
|
||||
/// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is
|
||||
/// either a valid type (float, int, or string). If this is not provided, treats all columns as string type.
|
||||
/// \param[in] column_name List of column names of the dataset (default={}). If this is not provided, infers the
|
||||
/// column_names from the first row of CSV file.
|
||||
/// \param[in] ag_news_list List of files to be read to search for a pattern of files. The list
|
||||
/// will be sorted in a lexicographical order.
|
||||
AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim,
|
||||
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
|
||||
const std::vector<std::string> &ag_news_list);
|
||||
|
||||
/// \brief Default destructor.
|
||||
~AGNewsOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[in] out he output stream to write output to.
|
||||
/// \param[in] show_all A bool to control if you want to show all info or just a
|
||||
/// summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "AGNewsOp"; }
|
||||
|
||||
// DatasetName name getter
|
||||
// \return DatasetName of the current Op
|
||||
std::string DatasetName(bool upper = false) const { return upper ? "AGNews" : "ag news"; }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_
|
|
@ -183,7 +183,7 @@ class CsvOp : public NonMappableLeafOp {
|
|||
// \return DatasetName of the current Op
|
||||
virtual std::string DatasetName(bool upper = false) const { return upper ? "CSV" : "csv"; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param line - the content of the row.
|
||||
// @param tensor_table - the tensor table to put the parsed data in.
|
||||
|
|
|
@ -74,6 +74,7 @@ constexpr char kTransferNode[] = "Transfer";
|
|||
constexpr char kZipNode[] = "Zip";
|
||||
|
||||
// Names for leaf IR node
|
||||
constexpr char kAGNewsNode[] = "AGNewsDataset";
|
||||
constexpr char kAlbumNode[] = "AlbumDataset";
|
||||
constexpr char kCelebANode[] = "CelebADataset";
|
||||
constexpr char kCifar100Node[] = "Cifar100Dataset";
|
||||
|
|
|
@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
|
|||
add_subdirectory(samplers)
|
||||
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
ag_news_node.cc
|
||||
album_node.cc
|
||||
celeba_node.cc
|
||||
cifar100_node.cc
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/ag_news_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor for AGNewsNode.
|
||||
AGNewsNode::AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle,
|
||||
const std::string &usage, int32_t num_shards, int32_t shard_id,
|
||||
const std::shared_ptr<DatasetCache> &cache)
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
usage_(usage),
|
||||
ag_news_files_list_(WalkAllFiles(usage, dataset_dir)) {
|
||||
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> AGNewsNode::Copy() {
|
||||
auto node =
|
||||
std::make_shared<AGNewsNode>(dataset_dir_, num_samples_, shuffle_, usage_, num_shards_, shard_id_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void AGNewsNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
|
||||
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
|
||||
}
|
||||
|
||||
Status AGNewsNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("AGNewsNode", usage_, {"train", "test", "all"}));
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "AGNewsNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (num_shards_ < 1) {
|
||||
std::string err_msg = "AGNewsNode: Invalid number of shards: " + std::to_string(num_shards_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsNode", num_shards_, shard_id_));
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsNode", "column_names", column_names_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build AGNewsNode.
|
||||
Status AGNewsNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
// Sort the dataset files in a lexicographical order.
|
||||
std::vector<std::string> sorted_dataset_files = ag_news_files_list_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
// Because AGNews does not have external column_defaults nor column_names parameters,
|
||||
// they need to be set before AGNewsOp is initialized.
|
||||
// AGNews data set is formatted as three columns of data, so three columns are added.
|
||||
std::vector<std::shared_ptr<AGNewsOp::BaseRecord>> column_default;
|
||||
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
|
||||
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
|
||||
column_default.push_back(std::make_shared<CsvOp::Record<std::string>>(AGNewsOp::STRING, ""));
|
||||
std::vector<std::string> column_name = {"index", "title", "description"};
|
||||
// AGNews data values are always delimited by a comma.
|
||||
char field_delim_ = ',';
|
||||
std::shared_ptr<AGNewsOp> ag_news_op =
|
||||
std::make_shared<AGNewsOp>(num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
|
||||
num_shards_, shard_id_, field_delim_, column_default, column_name, sorted_dataset_files);
|
||||
RETURN_IF_NOT_OK(ag_news_op->Init());
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp.
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
// First, get the number of rows in the dataset.
|
||||
RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows));
|
||||
// Add the shuffle op after this op.
|
||||
RETURN_IF_NOT_OK(
|
||||
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
ag_news_op->SetTotalRepeats(GetTotalRepeats());
|
||||
ag_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(ag_news_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node.
|
||||
Status AGNewsNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = shard_id_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size.
|
||||
Status AGNewsNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows));
|
||||
sample_size = num_samples_;
|
||||
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AGNewsNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["shuffle"] = shuffle_;
|
||||
args["num_shards"] = num_shards_;
|
||||
args["shard_id"] = shard_id_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and
|
||||
// should be promoted to its parent class. AGNews (for which internally is based off CSV)
|
||||
// by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree,
|
||||
// that cache can inherit this sampler from the leaf, providing sampling support from
|
||||
// the caching layer.
|
||||
// Should be promoted to its parent class.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status AGNewsNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a cache has been added into the ascendant tree over this AGNews node, then
|
||||
// the cache will be executing a sampler for fetching the data. As such, any
|
||||
// options in the AGNews node need to be reset to its defaults so that this
|
||||
// AGNews node will produce the full set of data into the cache.
|
||||
Status AGNewsNode::MakeSimpleProducer() {
|
||||
shard_id_ = 0;
|
||||
num_shards_ = 1;
|
||||
shuffle_ = ShuffleMode::kFalse;
|
||||
num_samples_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::string> AGNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
|
||||
std::vector<std::string> ag_news_files_list;
|
||||
Path train_prefix("train.csv");
|
||||
Path test_prefix("test.csv");
|
||||
Path dir(dataset_dir);
|
||||
|
||||
if (usage == "train") {
|
||||
Path temp_path = dir / train_prefix;
|
||||
ag_news_files_list.push_back(temp_path.ToString());
|
||||
} else if (usage == "test") {
|
||||
Path temp_path = dir / test_prefix;
|
||||
ag_news_files_list.push_back(temp_path.ToString());
|
||||
} else {
|
||||
Path temp_path = dir / train_prefix;
|
||||
ag_news_files_list.push_back(temp_path.ToString());
|
||||
Path temp_path1 = dir / test_prefix;
|
||||
ag_news_files_list.push_back(temp_path1.ToString());
|
||||
}
|
||||
return ag_news_files_list;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief class AGNewsNode.
|
||||
/// \brief Dataset derived class to represent AGNews dataset.
|
||||
class AGNewsNode : public NonMappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, const std::string &usage,
|
||||
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~AGNewsNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kAGNewsNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[in] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief A base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id The shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
int64_t NumSamples() const { return num_samples_; }
|
||||
ShuffleMode Shuffle() const { return shuffle_; }
|
||||
int32_t NumShards() const { return num_shards_; }
|
||||
int32_t ShardId() const { return shard_id_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief AGNews by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in
|
||||
/// the tree, that cache can inherit this sampler from the leaf, providing
|
||||
/// sampling support from the caching layer. That is why we setup the
|
||||
/// sampler for a leaf node that does not use sampling. Note: This
|
||||
/// function is common among NonMappableSourceNode and should be promoted
|
||||
/// to its parent class.
|
||||
/// \param[in] sampler The sampler to setup.
|
||||
/// \return Status of the function.
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this ag_news node,
|
||||
/// then the cache will be executing a sampler for fetching the data.
|
||||
/// As such, any options in the AGNews node need to be reset to its defaults
|
||||
/// so that this AGNews node will produce the full set of data into the cache.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its
|
||||
/// parent class.
|
||||
/// \return Status of the function.
|
||||
Status MakeSimpleProducer() override;
|
||||
|
||||
/// \brief Generate a list of read file names according to usage.
|
||||
/// \param[in] usage Part of dataset of AGNews.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \return std::vector<std::string> A list of read file names.
|
||||
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
char field_delim_;
|
||||
std::vector<std::shared_ptr<CsvBase>> column_defaults_;
|
||||
std::vector<std::string> column_names_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
std::vector<std::string> ag_news_files_list_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_
|
|
@ -993,6 +993,56 @@ inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
|
|||
return SchemaCharIF(StringToChar(schema_file));
|
||||
}
|
||||
|
||||
/// \class AGNewsDataset
|
||||
/// \brief A source dataset that reads and parses AG News datasets.
|
||||
class AGNewsDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of AGNewsDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of data list csv file to be read, can be "train", "test" or "all".
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||
/// \param[in] shuffle The mode for shuffling data every epoch.
|
||||
/// Can be any of:
|
||||
/// ShuffleMode.kFalse - No shuffling is performed.
|
||||
/// ShuffleMode.kFiles - Shuffle files only.
|
||||
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into.
|
||||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of AGNewsDataset.
|
||||
~AGNewsDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a AGNewsDataset.
|
||||
/// \note The generated dataset has three columns ['index', 'title', 'description'].
|
||||
/// The index range is [1, 4].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage One of "all", "train" or "test" (default = "all").
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||
/// (Default = 0 means all samples).
|
||||
/// \param[in] shuffle The mode for shuffling data every epoch
|
||||
/// (Default=ShuffleMode::kGlobal).
|
||||
/// Can be any of:
|
||||
/// ShuffleMode::kFalse - No shuffling is performed.
|
||||
/// ShuffleMode::kFiles - Shuffle files only.
|
||||
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
|
||||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified (Default = 0).
|
||||
/// \param[in] cache Tensor cache to use.(default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the AGNewsDataset.
|
||||
inline std::shared_ptr<AGNewsDataset> AGNews(const std::string &dataset_dir, const std::string &usage = "all",
|
||||
int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
|
||||
int32_t num_shards = 1, int32_t shard_id = 0,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<AGNewsDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
|
||||
num_shards, shard_id, cache);
|
||||
}
|
||||
|
||||
/// \class AlbumDataset
|
||||
/// \brief A source dataset for reading and parsing Album dataset.
|
||||
class AlbumDataset : public Dataset {
|
||||
|
|
|
@ -67,7 +67,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
|
||||
check_photo_tour_dataset
|
||||
check_photo_tour_dataset, check_ag_news_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -5075,6 +5075,93 @@ class ManifestDataset(MappableDataset):
|
|||
return self.class_indexing
|
||||
|
||||
|
||||
class AGNewsDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses AG News datasets.
|
||||
|
||||
The generated dataset has three columns: :py:obj:`[index, title, description]`.
|
||||
The tensor of column :py:obj:`index` is of the string type.
|
||||
The tensor of column :py:obj:`title` is of the string type.
|
||||
The tensor of column :py:obj:`description` is of the string type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Acceptable usages include `train`, `test` and `all` (default=None, all samples).
|
||||
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
|
||||
(default=Shuffle.GLOBAL).
|
||||
If shuffle is False, no shuffling will be performed;
|
||||
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
|
||||
Otherwise, there are two levels of shuffling:
|
||||
|
||||
- Shuffle.GLOBAL: Shuffle both the files and samples.
|
||||
|
||||
- Shuffle.FILES: Shuffle files only.
|
||||
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, 'num_samples' reflects the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Examples:
|
||||
>>> ag_news_dataset_dir = "/path/to/ag_news_dataset_file"
|
||||
>>> dataset = ds.AGNewsDataset(dataset_dir=ag_news_dataset_dir, usage='all')
|
||||
|
||||
About AGNews dataset:
|
||||
|
||||
AG is a collection of over 1 million news articles. The news articles were collected
|
||||
by ComeToMyHead from over 2,000 news sources in over 1 year of activity. ComeToMyHead
|
||||
is an academic news search engine that has been in operation since July 2004.
|
||||
The dataset is provided by academics for research purposes such as data mining
|
||||
(clustering, classification, etc.), information retrieval (ranking, searching, etc.),
|
||||
xml, data compression, data streaming, and any other non-commercial activities.
|
||||
AG's news topic classification dataset was constructed by selecting the four largest
|
||||
classes from the original corpus. Each class contains 30,000 training samples and
|
||||
1,900 test samples. The total number of training samples in train.csv is 120,000
|
||||
and the number of test samples in test.csv is 7,600.
|
||||
|
||||
You can unzip the dataset files into the following structure and read by MindSpore's API:
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── ag_news_dataset_dir
|
||||
├── classes.txt
|
||||
├── train.csv
|
||||
├── test.csv
|
||||
└── readme.txt
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@misc{zhang2015characterlevel,
|
||||
title={Character-level Convolutional Networks for Text Classification},
|
||||
author={Xiang Zhang and Junbo Zhao and Yann LeCun},
|
||||
year={2015},
|
||||
eprint={1509.01626},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_ag_news_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
|
||||
num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.AGNewsNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
|
||||
self.shard_id)
|
||||
|
||||
|
||||
class Cifar10Dataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Cifar10 dataset.
|
||||
|
|
|
@ -535,7 +535,7 @@ def check_generatordataset(method):
|
|||
raise ValueError("Neither columns_names nor schema are provided.")
|
||||
|
||||
if schema is not None:
|
||||
if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
|
||||
if not isinstance(schema, (datasets.Schema, str)):
|
||||
raise ValueError("schema should be a path to schema file or a schema object.")
|
||||
|
||||
# check optional argument
|
||||
|
@ -1728,3 +1728,33 @@ def check_fake_image_dataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_ag_news_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(AGNewsDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
|
||||
# check dataset_files; required argument
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
# check usage
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test", "all"], "usage")
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
@ -14,6 +14,7 @@ SET(DE_UT_SRCS
|
|||
c_api_audio_a_to_q_test.cc
|
||||
c_api_audio_r_to_z_test.cc
|
||||
c_api_cache_test.cc
|
||||
c_api_dataset_ag_news_test.cc
|
||||
c_api_dataset_album_test.cc
|
||||
c_api_dataset_cifar_test.cc
|
||||
c_api_dataset_cityscapes_test.cc
|
||||
|
|
|
@ -0,0 +1,560 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetBasic.";
|
||||
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 2 samples.
|
||||
EXPECT_EQ(i, 2);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsGetters.";
|
||||
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
EXPECT_EQ(types.size(), 3);
|
||||
EXPECT_EQ(types[0].ToString(), "string");
|
||||
EXPECT_EQ(types[1].ToString(), "string");
|
||||
EXPECT_EQ(types[2].ToString(), "string");
|
||||
EXPECT_EQ(shapes.size(), 3);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[2].ToString(), "<>");
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetFail.";
|
||||
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::string invalid_csv_file = "./NotExistFile";
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
std::shared_ptr<Dataset> ds0 = AGNews("", "test", 0);
|
||||
EXPECT_NE(ds0, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
|
||||
// Expect failure: invalid AGNews input.
|
||||
EXPECT_EQ(iter0, nullptr);
|
||||
// Create a AGNews Dataset with invalid usage.
|
||||
std::shared_ptr<Dataset> ds1 = AGNews(invalid_csv_file);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid AGNews input.
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
// Test invalid num_samples < -1.
|
||||
std::shared_ptr<Dataset> ds2 =
|
||||
AGNews(dataset_dir, "test", -1, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
// Expect failure: invalid AGNews input.
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
// Test invalid num_shards < 1.
|
||||
std::shared_ptr<Dataset> ds3 =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid AGNews input.
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
// Test invalid shard_id >= num_shards.
|
||||
std::shared_ptr<Dataset> ds4 =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2);
|
||||
EXPECT_NE(ds4, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
|
||||
// Expect failure: invalid AGNews input.
|
||||
EXPECT_EQ(iter4, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetNumSamples) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetNumSamples.";
|
||||
|
||||
// Create a AGNewsDataset, with single CSV file.
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "test", 2, ShuffleMode::kFalse);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it..
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 2 samples.
|
||||
EXPECT_EQ(i, 2);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetDistribution) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetDistribution.";
|
||||
|
||||
// Create a AGNewsDataset, with single CSV file.
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 1 samples.
|
||||
EXPECT_EQ(i, 1);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetMultiFiles) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetMultiFiles.";
|
||||
|
||||
// Create a AGNewsDataset, with single CSV file.
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "all", 0, ShuffleMode::kFalse);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"3", "Demand analysis",
|
||||
"\"Users simply click on the module they want to view to "
|
||||
"browse information about that module.\""},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
{"3", "UML Timing Diagram",
|
||||
"Information is mainly displayed using locally stored data and mapping, "
|
||||
"which is not timely and does not have the ability to update itself."},
|
||||
{"3", "In summary",
|
||||
"This paper implements a map visualization system for Hangzhou city "
|
||||
"information, using extensive knowledge of visualization techniques."},
|
||||
};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 5 samples.
|
||||
EXPECT_EQ(i, 5);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetHeader) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetHeader.";
|
||||
|
||||
// Create a AGNewsDataset, with single CSV file.
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 2 samples.
|
||||
EXPECT_EQ(i, 2);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleFilesA) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesA.";
|
||||
|
||||
// Set configuration.
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers =
|
||||
GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
|
||||
<< ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(130);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "all", 0, ShuffleMode::kFiles);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Demand analysis",
|
||||
"\"Users simply click on the module they want to view to "
|
||||
"browse information about that module.\""},
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"3", "UML Timing Diagram",
|
||||
"Information is mainly displayed using locally stored data and mapping, "
|
||||
"which is not timely and does not have the ability to update itself."},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
{"3", "In summary",
|
||||
"This paper implements a map visualization system for Hangzhou city "
|
||||
"information, using extensive knowledge of visualization techniques."},
|
||||
};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 5 samples.
|
||||
EXPECT_EQ(i, 5);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
// Restore configuration.
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(
|
||||
original_num_parallel_workers);
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleFilesB) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesB.";
|
||||
// Set configuration.
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers =
|
||||
GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
|
||||
<< ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(130);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "all", 0, ShuffleMode::kInfile);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data."},
|
||||
{"3", "Demand analysis",
|
||||
"\"Users simply click on the module they want to view to "
|
||||
"browse information about that module.\""},
|
||||
{"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""},
|
||||
{"3", "UML Timing Diagram",
|
||||
"Information is mainly displayed using locally stored data and mapping, "
|
||||
"which is not timely and does not have the ability to update itself."},
|
||||
{"3", "In summary",
|
||||
"This paper implements a map visualization system for Hangzhou city "
|
||||
"information, using extensive knowledge of visualization techniques."},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length()
|
||||
<< ", Text: " << ss.substr(0, 50);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 5 samples.
|
||||
EXPECT_EQ(i, 5);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
// Restore configuration.
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(
|
||||
original_num_parallel_workers);
|
||||
}
|
||||
|
||||
/// Feature: Test AGNewsDataset Dataset.
|
||||
/// Description: read AGNewsDataset data and get data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestAGNewsDatasetShuffleGlobal) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleGlobal.";
|
||||
// Test AGNews Dataset with GLOBLE shuffle.
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers =
|
||||
GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed
|
||||
<< ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
std::string dataset_dir = datasets_root_path_ + "/testAGNews";
|
||||
std::shared_ptr<Dataset> ds =
|
||||
AGNews(dataset_dir, "train", 0, ShuffleMode::kGlobal);
|
||||
std::vector<std::string> column_names = {"index", "title", "description"};
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
EXPECT_NE(row.find("index"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"3", "UML Timing Diagram",
|
||||
"Information is mainly displayed using locally stored data and mapping, "
|
||||
"which is not timely and does not have the ability to update itself."},
|
||||
{"3", "In summary",
|
||||
"This paper implements a map visualization system for Hangzhou city "
|
||||
"information, using extensive knowledge of visualization techniques."},
|
||||
{"3", "Demand analysis",
|
||||
"\"Users simply click on the module they want to view to "
|
||||
"browse information about that module.\""},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (int j = 0; j < column_names.size(); j++) {
|
||||
auto text = row[column_names[j]];
|
||||
std::shared_ptr<Tensor> de_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||
std::string_view sv;
|
||||
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||
std::string ss(sv);
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
// Expect 3 samples.
|
||||
EXPECT_EQ(i, 3);
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
// Restore configuration.
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(
|
||||
original_num_parallel_workers);
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
3,Background of the selection,"In this day and age, the internet is growing rapidly, the total number of connected devices is increasing and we are entering the era of big data."
|
||||
4,Related technologies,"""Leaflet is the leading open source JavaScript library for mobile-friendly interactive maps."""
|
|
|
@ -0,0 +1,3 @@
|
|||
3,Demand analysis,"""Users simply click on the module they want to view to browse information about that module."""
|
||||
3,UML Timing Diagram,"Information is mainly displayed using locally stored data and mapping, which is not timely and does not have the ability to update itself."
|
||||
3,In summary,"This paper implements a map visualization system for Hangzhou city information, using extensive knowledge of visualization techniques."
|
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import mindspore.dataset as ds
|
||||
|
||||
FILE_DIR = '../data/dataset/testAGNews'
|
||||
|
||||
|
||||
def test_ag_news_dataset_basic():
|
||||
"""
|
||||
Feature: Test AG News Dataset.
|
||||
Description: read data from a single file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
buffer = []
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False)
|
||||
data = data.repeat(2)
|
||||
data = data.skip(2)
|
||||
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
buffer.append(d)
|
||||
assert len(buffer) == 8
|
||||
|
||||
|
||||
def test_ag_news_dataset_one_file():
|
||||
"""
|
||||
Feature: Test AG News Dataset.
|
||||
Description: read data from a single file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
buffer = []
|
||||
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
buffer.append(d)
|
||||
assert len(buffer) == 2
|
||||
|
||||
|
||||
def test_ag_news_dataset_all_file():
|
||||
"""
|
||||
Feature: Test AG News Dataset(usage=all).
|
||||
Description: read train data and test data.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
buffer = []
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False)
|
||||
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
buffer.append(d)
|
||||
assert len(buffer) == 5
|
||||
|
||||
|
||||
def test_ag_news_dataset_num_samples():
|
||||
"""
|
||||
Feature: Test AG News Dataset.
|
||||
Description: read data from a single file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='all', num_samples=4, shuffle=False)
|
||||
count = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count += 1
|
||||
assert count == 4
|
||||
|
||||
|
||||
def test_ag_news_dataset_distribution():
|
||||
"""
|
||||
Feature: Test AG News Dataset.
|
||||
Description: read data from a single file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0)
|
||||
count = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count += 1
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_ag_news_dataset_quoted():
|
||||
"""
|
||||
Feature: Test get the AG News Dataset.
|
||||
Description: read AGNewsDataset data and get data.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
buffer = []
|
||||
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
buffer.extend([d['index'].item().decode("utf8"),
|
||||
d['title'].item().decode("utf8"),
|
||||
d['description'].item().decode("utf8")])
|
||||
assert buffer == ["3", "Background of the selection",
|
||||
"In this day and age, the internet is growing rapidly, "
|
||||
"the total number of connected devices is increasing and "
|
||||
"we are entering the era of big data.",
|
||||
"4", "Related technologies",
|
||||
"\"Leaflet is the leading open source JavaScript library "
|
||||
"for mobile-friendly interactive maps.\""]
|
||||
|
||||
|
||||
def test_ag_news_dataset_size():
|
||||
"""
|
||||
Feature: Test Getters.
|
||||
Description: test get_dataset_size of AG News dataset.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
assert data.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_ag_news_dataset_exception():
|
||||
"""
|
||||
Feature: Error Test.
|
||||
Description: test the wrong input.
|
||||
Expectation: unable to read in data.
|
||||
"""
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
try:
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False)
|
||||
data = data.map(operations=exception_func, input_columns=["description"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ag_news_dataset_basic()
|
||||
test_ag_news_dataset_one_file()
|
||||
test_ag_news_dataset_all_file()
|
||||
test_ag_news_dataset_num_samples()
|
||||
test_ag_news_dataset_distribution()
|
||||
test_ag_news_dataset_quoted()
|
||||
test_ag_news_dataset_size()
|
||||
test_ag_news_dataset_exception()
|
Loading…
Reference in New Issue