!22086 [assistant][ops] New operator implementation, including ImdbDataset

Merge pull request !22086 from ZJUTER0126/ImdbDataset
i-robot 2021-12-30 07:09:07 +00:00 committed by Gitee
commit 9850307d02
25 changed files with 1920 additions and 2 deletions

View File

@ -106,6 +106,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
@ -1266,6 +1267,30 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of IMDBDataset.
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of IMDBDataset.
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of IMDBDataset.
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
IWSLT2016Dataset::IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &language_pair,
const std::vector<char> &valid_set, const std::vector<char> &test_set,

View File

@ -44,6 +44,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
@ -316,6 +317,16 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
(void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode",
"to create an IMDBNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(imdb->ValidateParams());
return imdb;
}));
}));
PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
(void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>(
*m, "IWSLT2016Node", "to create an IWSLT2016Node")

View File

@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
fashion_mnist_op.cc
flickr_op.cc
image_folder_op.cc
imdb_op.cc
iwslt_op.cc
io_block.cc
kmnist_op.cc

View File

@ -0,0 +1,232 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/imdb_op.h"
#include <fstream>
#include <unordered_set>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
constexpr int32_t kNumClasses = 2;
IMDBOp::IMDBOp(int32_t num_workers, const std::string &file_dir, int32_t queue_size, const std::string &usage,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
folder_path_(std::move(file_dir)),
usage_(usage),
data_schema_(std::move(data_schema)),
sampler_ind_(0) {}
Status IMDBOp::PrepareData() {
std::vector<std::string> usage_list;
if (usage_ == "all") {
usage_list.push_back("train");
usage_list.push_back("test");
} else {
usage_list.push_back(usage_);
}
std::vector<std::string> label_list = {"pos", "neg"};
// get abs path for folder_path_
auto realpath = FileUtils::GetRealPath(folder_path_.data());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file path, imdb dataset dir: " << folder_path_ << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, imdb dataset dir: " + folder_path_ + " does not exist.");
}
Path base_dir(realpath.value());
for (auto usage : usage_list) {
for (auto label : label_list) {
Path dir = base_dir / usage / label;
RETURN_IF_NOT_OK(GetDataByUsage(dir.ToString(), label));
}
}
text_label_pairs_.shrink_to_fit();
num_rows_ = text_label_pairs_.size();
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
DatasetName() + " file path: " + folder_path_);
}
return Status::OK();
}
// Load 1 TensorRow (text, label) using 1 std::pair<std::string, int32_t>. 1 function call produces 1 TensorRow
Status IMDBOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
RETURN_UNEXPECTED_IF_NULL(trow);
std::pair<std::string, int32_t> pair_ptr = text_label_pairs_[row_id];
std::shared_ptr<Tensor> text, label;
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr.second, &label));
RETURN_IF_NOT_OK(LoadFile(pair_ptr.first, &text));
(*trow) = TensorRow(row_id, {std::move(text), std::move(label)});
trow->setPath({pair_ptr.first, std::string("")});
return Status::OK();
}
void IMDBOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows: " << num_rows_ << "\n"
<< DatasetName(true) << " directory: " << folder_path_ << "\nUsage: " << usage_ << "\n\n";
}
}
// Derived from RandomAccessOp
Status IMDBOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || text_label_pairs_.empty()) {
if (text_label_pairs_.empty()) {
RETURN_STATUS_UNEXPECTED("Invalid dataset dir, " + DatasetName(true) +
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
DatasetName() + " file path: " + folder_path_);
} else {
RETURN_STATUS_UNEXPECTED(
"[Internal ERROR], Map containing text-index pair is nullptr or has been set in other place, "
"it must be empty before using GetClassIds.");
}
}
for (size_t i = 0; i < text_label_pairs_.size(); ++i) {
(*cls_ids)[text_label_pairs_[i].second].push_back(i);
}
for (auto &pair : (*cls_ids)) {
pair.second.shrink_to_fit();
}
return Status::OK();
}
Status IMDBOp::GetDataByUsage(const std::string &folder, const std::string &label) {
Path dir_usage_label(folder);
if (!dir_usage_label.Exists() || !dir_usage_label.IsDirectory()) {
RETURN_STATUS_UNEXPECTED("Invalid parameter, dataset dir may not exist or is not a directory: " + folder);
}
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir_usage_label);
CHECK_FAIL_RETURN_UNEXPECTED(dir_itr != nullptr,
"Invalid path, failed to open imdb dir: " + folder + ", permission denied.");
std::map<std::string, int32_t> text_label_map;
while (dir_itr->HasNext()) {
Path file = dir_itr->Next();
text_label_map[file.ToString()] = (label == "pos") ? 1 : 0;
}
for (auto item : text_label_map) {
text_label_pairs_.emplace_back(std::make_pair(item.first, item.second));
}
return Status::OK();
}
Status IMDBOp::CountRows(const std::string &path, const std::string &usage, int64_t *num_rows) {
RETURN_UNEXPECTED_IF_NULL(num_rows);
// get abs path for the input path
auto abs_path = FileUtils::GetRealPath(path.data());
if (!abs_path.has_value()) {
MS_LOG(ERROR) << "Invalid file path, imdb dataset dir: " << path << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, imdb dataset dir: " + path + " does not exist.");
}
Path data_dir(abs_path.value());
std::vector<std::string> all_dirs_list = {"pos", "neg"};
std::vector<std::string> usage_list;
if (usage == "all") {
usage_list.push_back("train");
usage_list.push_back("test");
} else {
usage_list.push_back(usage);
}
int64_t row_cnt = 0;
for (int32_t ind = 0; ind < usage_list.size(); ++ind) {
Path texts_dir_usage_path = data_dir / usage_list[ind];
CHECK_FAIL_RETURN_UNEXPECTED(
texts_dir_usage_path.Exists() && texts_dir_usage_path.IsDirectory(),
"Invalid path, dataset path may not exist or is not a directory: " + texts_dir_usage_path.ToString());
for (auto dir : all_dirs_list) {
Path texts_dir_usage_dir_path((texts_dir_usage_path / dir).ToString());
std::shared_ptr<Path::DirIterator> dir_iter = Path::DirIterator::OpenDirectory(&texts_dir_usage_dir_path);
CHECK_FAIL_RETURN_UNEXPECTED(dir_iter != nullptr,
"Invalid path, failed to open imdb dir: " + path + ", permission denied.");
RETURN_UNEXPECTED_IF_NULL(dir_iter);
while (dir_iter->HasNext()) {
row_cnt++;
}
}
}
(*num_rows) = row_cnt;
return Status::OK();
}
Status IMDBOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
// Get number of classes
Status IMDBOp::GetNumClasses(int64_t *num_classes) {
RETURN_UNEXPECTED_IF_NULL(num_classes);
*num_classes = kNumClasses;
return Status::OK();
}
Status IMDBOp::LoadFile(const std::string &file, std::shared_ptr<Tensor> *out_row) {
RETURN_UNEXPECTED_IF_NULL(out_row);
std::ifstream handle(file);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
}
std::string line;
// Each IMDB text file contains only one line.
while (getline(handle, line)) {
if (line.empty()) {
continue;
}
auto rc = LoadTensor(line, out_row);
if (rc.IsError()) {
handle.close();
return rc;
}
}
handle.close();
return Status::OK();
}
Status IMDBOp::LoadTensor(const std::string &line, std::shared_ptr<Tensor> *out_row) {
RETURN_UNEXPECTED_IF_NULL(out_row);
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
*out_row = std::move(tensor);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// Forward declares
template <typename T>
class Queue;
class IMDBOp : public MappableLeafOp {
public:
/// \brief Constructor.
/// \param[in] int32_t num_workers - number of workers reading texts in parallel.
/// \param[in] std::string dataset_dir - directory of the IMDB dataset.
/// \param[in] int32_t queue_size - connector queue size.
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] DataSchema data_schema - the schema of each column in output data.
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells IMDBOp what to read.
IMDBOp(int32_t num_workers, const std::string &dataset_dir, int32_t queue_size, const std::string &usage,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~IMDBOp() = default;
/// \brief Parse IMDB data.
/// \return Status - The status code returned.
Status PrepareData() override;
/// \brief Method derived from RandomAccess Op, enable Sampler to get all ids for each class
/// \param[in] map cls_ids - key label, val all ids for this class
/// \return Status - The status code returned.
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
/// \brief A print method typically used for debugging.
/// \param[out] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief This function returns the num_rows.
/// \param[in] std::string path - directory of the IMDB dataset.
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[out] int64_t *num_rows - output arg that will hold the actual dataset size.
/// \return Status - The status code returned.
static Status CountRows(const std::string &path, const std::string &usage, int64_t *num_rows);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "IMDBOp"; }
/// \brief Dataset name getter.
/// \param[in] upper Whether to get upper name.
/// \return Dataset name of the current Op.
virtual std::string DatasetName(bool upper = false) const { return upper ? "IMDB" : "imdb"; }
/// \brief Base-class override for GetNumClasses
/// \param[out] int64_t *num_classes - the number of classes
/// \return Status - The status code returned.
Status GetNumClasses(int64_t *num_classes) override;
private:
/// \brief Load a tensor row according to a pair.
/// \param[in] uint64_t row_id - id of the row to load.
/// \param[out] TensorRow *row - text & label read into this tensor row.
/// \return Status - The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
/// \brief Parses a single row and puts the data into a tensor table.
/// \param[in] string line - the content of the row.
/// \param[out] Tensor *out_row - the tensor into which the row content is written.
/// \return Status - The status code returned.
Status LoadTensor(const std::string &line, std::shared_ptr<Tensor> *out_row);
/// \brief Reads a text file and loads the data into Tensor.
/// \param[in] string file - the file to read.
/// \param[out] Tensor *out_row - the tensor into which the file content is loaded.
/// \return Status - The status code returned.
Status LoadFile(const std::string &file, std::shared_ptr<Tensor> *out_row);
/// \brief Collect (file, label) pairs from one usage/label folder.
/// \param[in] string folder - the folder that contains the text files.
/// \param[in] string label - the label name ("pos" or "neg").
/// \return Status - The status code returned.
Status GetDataByUsage(const std::string &folder, const std::string &label);
/// \brief function for computing the assignment of the column name map.
/// \return Status - The status code returned.
Status ComputeColMap() override;
std::string folder_path_; // directory of text folder
std::string usage_;
int64_t sampler_ind_;
std::unique_ptr<DataSchema> data_schema_;
std::vector<std::pair<std::string, int32_t>> text_label_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_

View File

@ -96,6 +96,7 @@ constexpr char kFashionMnistNode[] = "FashionMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kIMDBNode[] = "IMDBDataset";
constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset";
constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset";
constexpr char kKMnistNode[] = "KMnistDataset";

View File

@ -22,6 +22,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
fashion_mnist_node.cc
flickr_node.cc
image_folder_node.cc
imdb_node.cc
iwslt2016_node.cc
iwslt2017_node.cc
kmnist_node.cc

View File

@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/imdb_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
IMDBNode::IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache = nullptr)
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler), usage_(usage) {}
std::shared_ptr<DatasetNode> IMDBNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<IMDBNode>(dataset_dir_, usage_, sampler, cache_);
return node;
}
void IMDBNode::Print(std::ostream &out) const {
out << (Name() + "(path: " + dataset_dir_ + ", usage: " + usage_ + ")");
}
Status IMDBNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("IMDBDataset", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("IMDBDataset", usage_, {"train", "test", "all"}));
RETURN_IF_NOT_OK(ValidateDatasetSampler("IMDBDataset", sampler_));
return Status::OK();
}
Status IMDBNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
RETURN_UNEXPECTED_IF_NULL(node_ops);
// Do internal Schema generation.
// This arg exists in IMDBOp, but is not externalized (in Python API).
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<IMDBOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, std::move(schema),
std::move(sampler_rt));
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node
Status IMDBNode::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size
Status IMDBNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
RETURN_UNEXPECTED_IF_NULL(dataset_size);
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(IMDBOp::CountRows(dataset_dir_, usage_, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status IMDBNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
Status IMDBNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
RETURN_UNEXPECTED_IF_NULL(ds);
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kIMDBNode));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kIMDBNode));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kIMDBNode));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kIMDBNode));
std::string dataset_dir = json_obj["dataset_dir"];
std::string usage = json_obj["usage"];
std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<IMDBNode>(dataset_dir, usage, sampler, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class IMDBNode
/// \brief A Dataset derived class to represent IMDB dataset
class IMDBNode : public MappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] std::string dataset_dir - directory of the IMDB dataset.
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells IMDBNode what to read.
/// \param[in] cache Tensor cache to use.
IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~IMDBNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kIMDBNode; }
/// \brief Print the description.
/// \param[out] ostream out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[out] DatasetOp *node_ops A vector containing shared pointer to the Dataset Ops that this object will
/// create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \param[out] int32_t *shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] DatasetSizeGetter size_getter Shared pointer to DatasetSizeGetter
/// \param[in] bool estimate This is only supported by some of the ops and it's used to speed up the process of
/// getting dataset size at the expense of accuracy.
/// \param[out] int64_t *dataset_size The size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
/// \brief Get the arguments of node
/// \param[out] json *out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter
/// \param[in] sampler Tells IMDBOp what to read.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
private:
std::string dataset_dir_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_

View File

@ -2727,6 +2727,94 @@ inline std::shared_ptr<ImageFolderDataset> MS_API ImageFolder(const std::string
MapStringToChar(class_indexing), cache);
}
/// \class IMDBDataset
/// \brief A source dataset for reading and parsing IMDB dataset.
class MS_API IMDBDataset : public Dataset {
public:
/// \brief Constructor of IMDBDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of IMDBDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of IMDBDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of IMDBDataset.
~IMDBDataset() = default;
};
/// \brief A source dataset for reading and parsing IMDB dataset.
/// \note The generated dataset has two columns ["text", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all"
/// (Default="all").
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the IMDBDataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string dataset_path = "/path/to/imdb_dataset_directory";
/// std::shared_ptr<Dataset> ds = IMDB(dataset_path, "all");
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In IMDB dataset, each data dictionary has keys "text" and "label" */
/// auto text = row["text"];
/// \endcode
inline std::shared_ptr<IMDBDataset> MS_API
IMDB(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \brief A source dataset for reading and parsing IMDB dataset.
/// \note The generated dataset has two columns ["text", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the IMDBDataset.
inline std::shared_ptr<IMDBDataset> MS_API IMDB(const std::string &dataset_dir, const std::string &usage,
const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \brief A source dataset for reading and parsing IMDB dataset.
/// \note The generated dataset has two columns ["text", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the IMDBDataset.
inline std::shared_ptr<IMDBDataset> MS_API IMDB(const std::string &dataset_dir, const std::string &usage,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class IWSLT2016Dataset.
/// \brief A source dataset for reading and parsing IWSLT2016 dataset.
class MS_API IWSLT2016Dataset : public Dataset {
@ -4890,7 +4978,7 @@ class MS_API WikiTextDataset : public Dataset {
/// iter->GetNextRow(&row);
///
/// /* Note: In WikiText dataset, each dictionary has key "text" */
/// auto text = row["image"];
/// auto text = row["text"];
/// \endcode
inline std::shared_ptr<WikiTextDataset> MS_API WikiText(const std::string &dataset_dir,
const std::string &usage = "all", int64_t num_samples = 0,

View File

@ -47,6 +47,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
friend class FashionMnistDataset;
friend class FlickrDataset;
friend class ImageFolderDataset;
friend class IMDBDataset;
friend class KMnistDataset;
friend class LJSpeechDataset;
friend class ManifestDataset;

View File

@ -76,7 +76,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \
check_conll2000_dataset, check_amazon_review_dataset, check_semeion_dataset, check_caltech101_dataset, \
check_caltech256_dataset, check_wiki_text_dataset
check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -3670,6 +3670,147 @@ class ImageFolderDataset(MappableDataset):
return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing)
class IMDBDataset(MappableDataset):
"""
A source dataset for reading and parsing Internet Movie Database (IMDb).
The generated dataset has two columns: :py:obj:`[text, label]`.
The tensor of column :py:obj:`text` is of the string type.
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be `train`, `test` or `all`
(default=None, will read all samples).
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, will read all samples).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, set in the config).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, `num_samples` reflects
the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
    - The `text` column contains exactly one scalar string per sample.
    - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
      The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
   :widths: 25 25 50
   :header-rows: 1

   * - Parameter `sampler`
     - Parameter `shuffle`
     - Expected Order Behavior
   * - None
     - None
     - random order
   * - None
     - True
     - random order
   * - None
     - False
     - sequential order
   * - Sampler object
     - None
     - order defined by sampler
   * - Sampler object
     - True
     - not allowed
   * - Sampler object
     - False
     - not allowed
Examples:
>>> imdb_dataset_dir = "/path/to/imdb_dataset_directory"
>>>
>>> # 1) Read all samples (text files) in imdb_dataset_dir with 8 threads
>>> dataset = ds.IMDBDataset(dataset_dir=imdb_dataset_dir, num_parallel_workers=8)
>>>
>>> # 2) Read train samples (text files).
>>> dataset = ds.IMDBDataset(dataset_dir=imdb_dataset_dir, usage="train")
About IMDBDataset:
The IMDB dataset contains 50,000 highly polarized reviews from the Internet Movie Database (IMDB). The dataset
is split into 25,000 reviews for training and 25,000 reviews for testing, with both the training set and the test
set containing 50% positive and 50% negative reviews. Train labels and test labels are lists of 0s and 1s, where
0 stands for negative and 1 for positive.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::

    .
    └── imdb_dataset_directory
        ├── train
        │   ├── pos
        │   │   ├── 0_9.txt
        │   │   ├── 1_7.txt
        │   │   └── ...
        │   └── neg
        │       ├── 0_3.txt
        │       ├── 1_1.txt
        │       └── ...
        └── test
            ├── pos
            │   ├── 0_10.txt
            │   ├── 1_10.txt
            │   └── ...
            └── neg
                ├── 0_2.txt
                ├── 1_3.txt
                └── ...
Citation:
.. code-block::
@InProceedings{maas-EtAl:2011:ACL-HLT2011,
author = {Maas, Andrew L. and Daly, Raymond E. and Pham, Peter T. and Huang, Dan
and Ng, Andrew Y. and Potts, Christopher},
title = {Learning Word Vectors for Sentiment Analysis},
booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics:
Human Language Technologies},
month = {June},
year = {2011},
address = {Portland, Oregon, USA},
publisher = {Association for Computational Linguistics},
pages = {142--150},
url = {http://www.aclweb.org/anthology/P11-1015}
}
"""
@check_imdb_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.IMDBNode(self.dataset_dir, self.usage, self.sampler)
class IWSLT2016Dataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses IWSLT2016 datasets.

View File

@ -63,6 +63,35 @@ def check_imagefolderdataset(method):
return new_method
def check_imdb_dataset(method):
"""A wrapper that wraps a parameter checker around the original IMDBDataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "all"], "usage")
return method(self, *args, **kwargs)
return new_method
def check_iwslt2016_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016dataset)."""

View File

@ -31,6 +31,7 @@ SET(DE_UT_SRCS
c_api_dataset_fake_image_test.cc
c_api_dataset_fashion_mnist_test.cc
c_api_dataset_flickr_test.cc
c_api_dataset_imdb_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_iwslt_test.cc
c_api_dataset_kmnist_test.cc

View File

@ -0,0 +1,260 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::Tensor;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data and get all data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestIMDBBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBBasic.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "all"; // 'train', 'test', 'all'
// Create an IMDB Dataset
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto text = row["text"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 8);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data and get train data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestIMDBTrain) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBTrain.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "train"; // 'train', 'test', 'all'
// Create an IMDB Dataset
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto text = row["text"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 4);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data and get test data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestIMDBTest) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBTest.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "test"; // 'train', 'test', 'all'
// Create an IMDB Dataset
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto text = row["text"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 4);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data and test pipeline.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestIMDBBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBBasicWithPipeline.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "all"; // 'train', 'test', 'all'
// Create two IMDB Datasets
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, usage);
std::shared_ptr<Dataset> ds2 = IMDB(dataset_path, usage);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operations on ds
int32_t repeat_num = 3;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 2;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto text = row["text"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 40);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data with GetDatasetSize, GetColumnNames, GetBatchSize.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestIMDBGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBGetters.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "all"; // 'train', 'test', 'all'
// Create an IMDB Dataset
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, usage);
std::vector<std::string> column_names = {"text", "label"};
std::vector<DataType> types = ToDETypes(ds1->GetOutputTypes());
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(types[1].ToString(), "int32");
EXPECT_NE(ds1, nullptr);
EXPECT_EQ(ds1->GetDatasetSize(), 8);
EXPECT_EQ(ds1->GetColumnNames(), column_names);
EXPECT_EQ(ds1->GetBatchSize(), 1);
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data with an invalid directory and an invalid usage.
/// Expectation: fail to create an iterator.
TEST_F(MindDataTestPipeline, TestIMDBError) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBError.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "all"; // 'train', 'test', 'all'
// Create an IMDB Dataset with a non-existing dataset dir
std::shared_ptr<Dataset> ds0 = IMDB("NotExistDir", usage);
EXPECT_NE(ds0, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
// Expect failure: invalid IMDB input
EXPECT_EQ(iter0, nullptr);
// Create an IMDB Dataset with an invalid usage
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, "val");
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid IMDB input
EXPECT_EQ(iter1, nullptr);
}
/// Feature: Test IMDB Dataset.
/// Description: read IMDB data with a null sampler.
/// Expectation: fail to create an iterator.
TEST_F(MindDataTestPipeline, TestIMDBWithNullSamplerError) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBWithNullSamplerError.";
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
std::string usage = "all";
// Create an IMDB Dataset
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid IMDB input, sampler cannot be nullptr
EXPECT_EQ(iter, nullptr);
}
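As a further illustrative sketch (not part of the change set), the usage splits exercised above can also be checked directly through GetDatasetSize(); this assumes the code sits inside a TEST_F(MindDataTestPipeline, ...) body so that datasets_root_path_ and EXPECT_EQ are available:

// Sizes follow from the testIMDBDataset fixture added in this PR: 2 files per
// train/test x pos/neg folder. The extra train/unsup file is ignored because
// IMDBOp only scans the "pos" and "neg" folders.
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
EXPECT_EQ(IMDB(dataset_path, "train")->GetDatasetSize(), 4);  // train/pos + train/neg
EXPECT_EQ(IMDB(dataset_path, "test")->GetDatasetSize(), 4);   // test/pos + test/neg
EXPECT_EQ(IMDB(dataset_path, "all")->GetDatasetSize(), 8);    // both splits combined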

View File

@ -0,0 +1 @@
test_neg_0.txt

View File

@ -0,0 +1 @@
test_neg_1.txt

View File

@ -0,0 +1 @@
test_pos_0.txt

View File

@ -0,0 +1 @@
test_pos_1.txt

View File

@ -0,0 +1 @@
train_neg_0.txt

View File

@ -0,0 +1 @@
train_neg_1.txt

View File

@ -0,0 +1 @@
train_pos_0.txt

View File

@ -0,0 +1 @@
train_pos_1.txt

View File

@ -0,0 +1 @@
train_unsup_0.txt

View File

@ -0,0 +1,732 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = "../data/dataset/testIMDBDataset"
def test_imdb_basic():
"""
Feature: Test IMDB Dataset.
Description: read data from all files.
Expectation: the data is processed successfully.
"""
logger.info("Test Case basic")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 8
content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt",
"test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
label = [1, 1, 0, 0, 1, 1, 0, 0]
num_iter = 0
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
# each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
strs = item["text"].item().decode("utf8")
logger.info("text is {}".format(strs))
logger.info("label is {}".format(item["label"]))
assert strs == content[index]
assert label[index] == int(item["label"])
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
def test_imdb_test():
"""
Feature: Test IMDB Dataset.
Description: read data from test files.
Expectation: the data is processed successfully.
"""
logger.info("Test Case test")
# define parameters
repeat_count = 1
usage = "test"
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4
content = ["test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
label = [1, 1, 0, 0]
num_iter = 0
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
# each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
strs = item["text"].item().decode("utf8")
logger.info("text is {}".format(strs))
logger.info("label is {}".format(item["label"]))
assert strs == content[index]
assert label[index] == int(item["label"])
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_imdb_train():
"""
Feature: Test IMDB Dataset.
Description: read data from train files.
Expectation: the data is processed successfully.
"""
logger.info("Test Case train")
# define parameters
repeat_count = 1
usage = "train"
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4
content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt"]
label = [1, 1, 0, 0]
num_iter = 0
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
# each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
strs = item["text"].item().decode("utf8")
logger.info("text is {}".format(strs))
logger.info("label is {}".format(item["label"]))
assert strs == content[index]
assert label[index] == int(item["label"])
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_imdb_num_samples():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with num_samples=6 and num_parallel_workers=2.
Expectation: the data is processed successfully.
"""
logger.info("Test Case numSamples")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, num_samples=6, num_parallel_workers=2)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
def test_imdb_num_shards():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with num_shards=2 and shard_id=1.
Expectation: the data is processed successfully.
"""
logger.info("Test Case numShards")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=1)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_imdb_shard_id():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with num_shards=2 and shard_id=0.
Expectation: the data is processed successfully.
"""
logger.info("Test Case withShardID")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=0)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_imdb_no_shuffle():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with shuffle=False.
Expectation: the data is processed successfully.
"""
logger.info("Test Case noShuffle")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 8
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
def test_imdb_true_shuffle():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with shuffle=True.
Expectation: the data is processed successfully.
"""
logger.info("Test Case extraShuffle")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, shuffle=True)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 16
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 16
def test_random_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with sampler=ds.RandomSampler().
Expectation: the data is processed successfully.
"""
logger.info("Test Case RandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.RandomSampler()
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 8
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
def test_distributed_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with sampler=ds.DistributedSampler().
Expectation: the data is processed successfully.
"""
logger.info("Test Case DistributedSampler")
# define parameters
repeat_count = 1
# apply dataset operations
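# DistributedSampler(4, 1) keeps only shard 1 of 4, i.e. 8 / 4 = 2 rows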
sampler = ds.DistributedSampler(4, 1)
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 2
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
def test_pk_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with sampler=ds.PKSampler().
Expectation: the data is processed successfully.
"""
logger.info("Test Case PKSampler")
# define parameters
repeat_count = 1
# apply dataset operations
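# PKSampler(3) draws 3 samples per class; with the two IMDB classes (pos/neg) that is 6 rows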
sampler = ds.PKSampler(3)
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6
def test_subset_random_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with sampler=ds.SubsetRandomSampler().
Expectation: the data is processed successfully.
"""
logger.info("Test Case SubsetRandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
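# SubsetRandomSampler visits exactly the 6 listed indices, in random order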
indices = [0, 3, 1, 2, 5, 4]
sampler = ds.SubsetRandomSampler(indices)
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6
def test_weighted_random_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with sampler=ds.WeightedRandomSampler().
Expectation: the data is processed successfully.
"""
logger.info("Test Case WeightedRandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
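# the six weights correspond to the first six rows; num_samples=6 fixes how many rows
# are drawn, so the dataset size is 6 regardless of the individual weight values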
weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
sampler = ds.WeightedRandomSampler(weights, 6)
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6
def test_weighted_random_sampler_exception():
"""
Feature: Test IMDB Dataset.
Description: test WeightedRandomSampler error cases on the IMDB dataset.
Expectation: the expected TypeError or RuntimeError is raised with a matching message.
"""
logger.info("Test error cases for WeightedRandomSampler")
error_msg_1 = "type of weights element must be number"
with pytest.raises(TypeError, match=error_msg_1):
weights = ""
ds.WeightedRandomSampler(weights)
error_msg_2 = "type of weights element must be number"
with pytest.raises(TypeError, match=error_msg_2):
weights = (0.9, 0.8, 1.1)
ds.WeightedRandomSampler(weights)
error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
with pytest.raises(RuntimeError, match=error_msg_3):
weights = []
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
with pytest.raises(RuntimeError, match=error_msg_4):
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
with pytest.raises(RuntimeError, match=error_msg_5):
weights = [0, 0, 0, 0, 0]
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
def test_chained_sampler_with_random_sequential_repeat():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with chained Random and Sequential samplers, with repeat.
Expectation: the data is processed successfully.
"""
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
# Create chained sampler, random and sequential
sampler = ds.RandomSampler()
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
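# the child SequentialSampler feeds every id to the parent RandomSampler, which shuffles
# them, so each epoch still covers all 8 rows (24 rows with repeat(count=3))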
# Create IMDBDataset with sampler
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(count=3)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 24
# Verify number of iterations
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 24
def test_chained_sampler_with_distribute_random_batch_then_repeat():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with chained Distributed and Random samplers, with batch then repeat.
Expectation: the data is processed successfully.
"""
logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")
# Create chained sampler, distributed and random
sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
child_sampler = ds.RandomSampler()
sampler.add_child(child_sampler)
# Create IMDBDataset with sampler
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
data1 = data1.batch(batch_size=5, drop_remainder=True)
data1 = data1.repeat(count=3)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 0
# Verify number of iterations
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
# Note: each of the 4 shards has 8/4 = 2 samples
# Note: batch(5, drop_remainder=True) on 2 samples yields 0 batches, so 0 iterations even after repeat(3)
assert num_iter == 0
def test_chained_sampler_with_weighted_random_pk_sampler():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with chained WeightedRandom and PK samplers.
Expectation: the data is processed successfully.
"""
logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")
# Create chained sampler, WeightedRandom and PKSampler
weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
sampler = ds.WeightedRandomSampler(weights=weights, num_samples=6)
child_sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 2 classes: pos and neg)
sampler.add_child(child_sampler)
# Create IMDBDataset with sampler
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
# Verify number of iterations
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
# Note: child PKSampler supplies 3 ids per class, 6 in total
# Note: parent WeightedRandomSampler draws num_samples=6 from that pool
assert num_iter == 6
def test_imdb_rename():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with rename.
Expectation: the data is processed successfully.
"""
logger.info("Test Case rename")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, num_samples=8)
data1 = data1.repeat(repeat_count)
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
data1 = data1.rename(input_columns=["text"], output_columns="text2")
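# after the rename, the text column is only reachable through its new name "text2"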
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text2"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
def test_imdb_zip():
"""
Feature: Test IMDB Dataset.
Description: read data from all files with zip.
Expectation: the data is processed successfully.
"""
logger.info("Test Case zip")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.IMDBDataset(DATA_DIR, num_samples=4)
data2 = ds.IMDBDataset(DATA_DIR, num_samples=4)
data1 = data1.repeat(repeat_count)
# rename dataset2 for no conflict
data2 = data2.rename(input_columns=["text", "label"], output_columns=["text1", "label1"])
data3 = ds.zip((data1, data2))
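# zip stops at the shorter input: data2 has 4 rows, so the zipped dataset has 4 rows
# even though data1 was repeated to 8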
num_iter = 0
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"].item().decode("utf8")))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_imdb_exception():
"""
Feature: Test IMDB Dataset.
Description: test IMDBDataset error cases: a failing map PyFunc and an invalid dataset directory.
Expectation: the expected RuntimeError or ValueError is raised with a matching message.
"""
logger.info("Test imdb exception")
def exception_func(item):
raise Exception("Error occurred!")
def exception_func2(text, label):
raise Exception("Error occurred!")
try:
data = ds.IMDBDataset(DATA_DIR)
data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.IMDBDataset(DATA_DIR)
data = data.map(operations=exception_func2, input_columns=["text", "label"],
output_columns=["text", "label", "label1"],
column_order=["text", "label", "label1"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
data_dir_invalid = "../data/dataset/IMDBDATASET"
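# a directory that does not exist is rejected by parameter validation with ValueError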
try:
data = ds.IMDBDataset(data_dir_invalid)
for _ in data.__iter__():
pass
assert False
except ValueError as e:
assert "does not exist or is not a directory or permission denied" in str(e)
if __name__ == '__main__':
test_imdb_basic()
test_imdb_test()
test_imdb_train()
test_imdb_num_samples()
test_random_sampler()
test_distributed_sampler()
test_pk_sampler()
test_subset_random_sampler()
test_weighted_random_sampler()
test_weighted_random_sampler_exception()
test_chained_sampler_with_random_sequential_repeat()
test_chained_sampler_with_distribute_random_batch_then_repeat()
test_chained_sampler_with_weighted_random_pk_sampler()
test_imdb_num_shards()
test_imdb_shard_id()
test_imdb_no_shuffle()
test_imdb_true_shuffle()
test_imdb_rename()
test_imdb_zip()
test_imdb_exception()