forked from mindspore-Ecosystem/mindspore
!22086 [assistant][ops]New operator implementation, include ImdbDataset
Merge pull request !22086 from ZJUTER0126/ImdbDataset
This commit is contained in:
commit
9850307d02
|
@ -106,6 +106,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
|
||||
|
@ -1266,6 +1267,30 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
// Create logical representation of IMDBDataset.
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
// Create logical representation of IMDBDataset.
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
// Create logical representation of IMDBDataset.
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
IWSLT2016Dataset::IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<std::vector<char>> &language_pair,
|
||||
const std::vector<char> &valid_set, const std::vector<char> &test_set,
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
|
||||
|
@ -316,6 +317,16 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode",
|
||||
"to create an IMDBNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
||||
auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(imdb->ValidateParams());
|
||||
return imdb;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>(
|
||||
*m, "IWSLT2016Node", "to create an IWSLT2016Node")
|
||||
|
|
|
@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
fashion_mnist_op.cc
|
||||
flickr_op.cc
|
||||
image_folder_op.cc
|
||||
imdb_op.cc
|
||||
iwslt_op.cc
|
||||
io_block.cc
|
||||
kmnist_op.cc
|
||||
|
|
|
@ -0,0 +1,232 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/imdb_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
constexpr int32_t kNumClasses = 2;
|
||||
|
||||
IMDBOp::IMDBOp(int32_t num_workers, const std::string &file_dir, int32_t queue_size, const std::string &usage,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
folder_path_(std::move(file_dir)),
|
||||
usage_(usage),
|
||||
data_schema_(std::move(data_schema)),
|
||||
sampler_ind_(0) {}
|
||||
|
||||
Status IMDBOp::PrepareData() {
|
||||
std::vector<std::string> usage_list;
|
||||
if (usage_ == "all") {
|
||||
usage_list.push_back("train");
|
||||
usage_list.push_back("test");
|
||||
} else {
|
||||
usage_list.push_back(usage_);
|
||||
}
|
||||
std::vector<std::string> label_list = {"pos", "neg"};
|
||||
// get abs path for folder_path_
|
||||
auto realpath = FileUtils::GetRealPath(folder_path_.data());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Invalid file path, imdb dataset dir: " << folder_path_ << " does not exist.";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file path, imdb dataset dir: " + folder_path_ + " does not exist.");
|
||||
}
|
||||
Path base_dir(realpath.value());
|
||||
for (auto usage : usage_list) {
|
||||
for (auto label : label_list) {
|
||||
Path dir = base_dir / usage / label;
|
||||
RETURN_IF_NOT_OK(GetDataByUsage(dir.ToString(), label));
|
||||
}
|
||||
}
|
||||
text_label_pairs_.shrink_to_fit();
|
||||
num_rows_ = text_label_pairs_.size();
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||
DatasetName() + " file path: " + folder_path_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Load 1 TensorRow (text, label) using 1 std::pair<std::string, int32_t>. 1 function call produces 1 TensorTow
|
||||
Status IMDBOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::pair<std::string, int32_t> pair_ptr = text_label_pairs_[row_id];
|
||||
std::shared_ptr<Tensor> text, label;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr.second, &label));
|
||||
RETURN_IF_NOT_OK(LoadFile(pair_ptr.first, &text));
|
||||
|
||||
(*trow) = TensorRow(row_id, {std::move(text), std::move(label)});
|
||||
trow->setPath({pair_ptr.first, std::string("")});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IMDBOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows: " << num_rows_ << "\n"
|
||||
<< DatasetName(true) << " directory: " << folder_path_ << "\nUsage: " << usage_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status IMDBOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || text_label_pairs_.empty()) {
|
||||
if (text_label_pairs_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid dataset dir, " + DatasetName(true) +
|
||||
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||
DatasetName() + " file path: " + folder_path_);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR], Map containing text-index pair is nullptr or has been set in other place, "
|
||||
"it must be empty before using GetClassIds.");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < text_label_pairs_.size(); ++i) {
|
||||
(*cls_ids)[text_label_pairs_[i].second].push_back(i);
|
||||
}
|
||||
for (auto &pair : (*cls_ids)) {
|
||||
pair.second.shrink_to_fit();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBOp::GetDataByUsage(const std::string &folder, const std::string &label) {
|
||||
Path dir_usage_label(folder);
|
||||
if (!dir_usage_label.Exists() || !dir_usage_label.IsDirectory()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid parameter, dataset dir may not exist or is not a directory: " + folder);
|
||||
}
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir_usage_label);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dir_itr != nullptr,
|
||||
"Invalid path, failed to open imdb dir: " + folder + ", permission denied.");
|
||||
std::map<std::string, int32_t> text_label_map;
|
||||
while (dir_itr->HasNext()) {
|
||||
Path file = dir_itr->Next();
|
||||
text_label_map[file.ToString()] = (label == "pos") ? 1 : 0;
|
||||
}
|
||||
for (auto item : text_label_map) {
|
||||
text_label_pairs_.emplace_back(std::make_pair(item.first, item.second));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBOp::CountRows(const std::string &path, const std::string &usage, int64_t *num_rows) {
|
||||
RETURN_UNEXPECTED_IF_NULL(num_rows);
|
||||
// get abs path for folder_path_
|
||||
auto abs_path = FileUtils::GetRealPath(path.data());
|
||||
if (!abs_path.has_value()) {
|
||||
MS_LOG(ERROR) << "Invalid file path, imdb dataset dir: " << path << " does not exist.";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file path, imdb dataset dir: " + path + " does not exist.");
|
||||
}
|
||||
Path data_dir(abs_path.value());
|
||||
std::vector<std::string> all_dirs_list = {"pos", "neg"};
|
||||
std::vector<std::string> usage_list;
|
||||
if (usage == "all") {
|
||||
usage_list.push_back("train");
|
||||
usage_list.push_back("test");
|
||||
} else {
|
||||
usage_list.push_back(usage);
|
||||
}
|
||||
int64_t row_cnt = 0;
|
||||
for (int32_t ind = 0; ind < usage_list.size(); ++ind) {
|
||||
Path texts_dir_usage_path = data_dir / usage_list[ind];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
texts_dir_usage_path.Exists() && texts_dir_usage_path.IsDirectory(),
|
||||
"Invalid path, dataset path may not exist or is not a directory: " + texts_dir_usage_path.ToString());
|
||||
|
||||
for (auto dir : all_dirs_list) {
|
||||
Path texts_dir_usage_dir_path((texts_dir_usage_path / dir).ToString());
|
||||
std::shared_ptr<Path::DirIterator> dir_iter = Path::DirIterator::OpenDirectory(&texts_dir_usage_dir_path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dir_iter != nullptr,
|
||||
"Invalid path, failed to open imdb dir: " + path + ", permission denied.");
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_iter);
|
||||
while (dir_iter->HasNext()) {
|
||||
row_cnt++;
|
||||
}
|
||||
}
|
||||
}
|
||||
(*num_rows) = row_cnt;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBOp::ComputeColMap() {
|
||||
// Set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get number of classes
|
||||
Status IMDBOp::GetNumClasses(int64_t *num_classes) {
|
||||
RETURN_UNEXPECTED_IF_NULL(num_classes);
|
||||
*num_classes = kNumClasses;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBOp::LoadFile(const std::string &file, std::shared_ptr<Tensor> *out_row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out_row);
|
||||
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
||||
}
|
||||
|
||||
std::string line;
|
||||
// IMDB just have a line for every txt.
|
||||
while (getline(handle, line)) {
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto rc = LoadTensor(line, out_row);
|
||||
if (rc.IsError()) {
|
||||
handle.close();
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
handle.close();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBOp::LoadTensor(const std::string &line, std::shared_ptr<Tensor> *out_row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out_row);
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
|
||||
*out_row = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,134 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Forward declares
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
class IMDBOp : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] int32_t num_workers - num of workers reading texts in parallel.
|
||||
/// \param[in] std::string dataset_dir - dir directory of IMDB dataset.
|
||||
/// \param[in] int32_t queue_size - connector queue size.
|
||||
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] DataSchema data_schema - the schema of each column in output data.
|
||||
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells Folder what to read.
|
||||
IMDBOp(int32_t num_workers, const std::string &dataset_dir, int32_t queue_size, const std::string &usage,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~IMDBOp() = default;
|
||||
|
||||
/// \brief Parse IMDB data.
|
||||
/// \return Status - The status code returned.
|
||||
Status PrepareData() override;
|
||||
|
||||
/// \brief Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
/// \param[in] map cls_ids - key label, val all ids for this class
|
||||
/// \return Status - The status code returned.
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief This function return the num_rows.
|
||||
/// \param[in] std::string path - dir directory of IMDB dataset.
|
||||
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[out] int64_t *num_rows - output arg that will hold the actual dataset size.
|
||||
/// \return Status - The status code returned.
|
||||
static Status CountRows(const std::string &path, const std::string &usage, int64_t *num_rows);
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "IMDBOp"; }
|
||||
|
||||
/// \brief Dataset name getter.
|
||||
/// \param[in] upper Whether to get upper name.
|
||||
/// \return Dataset name of the current Op.
|
||||
virtual std::string DatasetName(bool upper = false) const { return upper ? "IMDB" : "imdb"; }
|
||||
|
||||
/// \brief Base-class override for GetNumClasses
|
||||
/// \param[out] int64_t *num_classes - the number of classes
|
||||
/// \return Status - The status code returned.
|
||||
Status GetNumClasses(int64_t *num_classes) override;
|
||||
|
||||
private:
|
||||
/// \brief Load a tensor row according to a pair.
|
||||
/// \param[in] uint64_t row_id - row_id need to load.
|
||||
/// \param[out] TensorRow *row - text & task read into this tensor row.
|
||||
/// \return Status - The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
/// \brief Parses a single row and puts the data into a tensor table.
|
||||
/// \param[in] string line - the content of the row.
|
||||
/// \param[out] Tensor *out_row - the id of the row filled in the tensor table.
|
||||
/// \return Status - The status code returned.
|
||||
Status LoadTensor(const std::string &line, std::shared_ptr<Tensor> *out_row);
|
||||
|
||||
/// \brief Reads a text file and loads the data into Tensor.
|
||||
/// \param[in] string file - the file to read.
|
||||
/// \param[out] Tensor *out_row - the id of the row filled in the tensor table.
|
||||
/// \return Status - The status code returned.
|
||||
Status LoadFile(const std::string &file, std::shared_ptr<Tensor> *out_row);
|
||||
|
||||
/// \brief Called first when function is called
|
||||
/// \param[in] string folder - the folder include files.
|
||||
/// \param[in] string label - the name of label.
|
||||
/// \return Status - The status code returned.
|
||||
Status GetDataByUsage(const std::string &folder, const std::string &label);
|
||||
|
||||
/// \brief function for computing the assignment of the column name map.
|
||||
/// \return Status - The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
std::string folder_path_; // directory of text folder
|
||||
std::string usage_;
|
||||
int64_t sampler_ind_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::vector<std::pair<std::string, int32_t>> text_label_pairs_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMDB_OP_H_
|
|
@ -96,6 +96,7 @@ constexpr char kFashionMnistNode[] = "FashionMnistDataset";
|
|||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
constexpr char kIMDBNode[] = "IMDBDataset";
|
||||
constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset";
|
||||
constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset";
|
||||
constexpr char kKMnistNode[] = "KMnistDataset";
|
||||
|
|
|
@ -22,6 +22,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
fashion_mnist_node.cc
|
||||
flickr_node.cc
|
||||
image_folder_node.cc
|
||||
imdb_node.cc
|
||||
iwslt2016_node.cc
|
||||
iwslt2017_node.cc
|
||||
kmnist_node.cc
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/imdb_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
IMDBNode::IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache = nullptr)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler), usage_(usage) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> IMDBNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<IMDBNode>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void IMDBNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(path: " + dataset_dir_ + ", usage: " + usage_ + ")");
|
||||
}
|
||||
|
||||
Status IMDBNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("IMDBDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("IMDBDataset", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("IMDBDataset", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node_ops);
|
||||
// Do internal Schema generation.
|
||||
// This arg is exist in IMDBOp, but not externalized (in Python API).
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<IMDBOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status IMDBNode::GetShardId(int32_t *shard_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(shard_id);
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status IMDBNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dataset_size);
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t sample_size, num_rows;
|
||||
RETURN_IF_NOT_OK(IMDBOp::CountRows(dataset_dir_, usage_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IMDBNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status IMDBNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ds);
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kIMDBNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kIMDBNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kIMDBNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kIMDBNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
|
||||
std::shared_ptr<DatasetCache> cache = nullptr;
|
||||
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
|
||||
*ds = std::make_shared<IMDBNode>(dataset_dir, usage, sampler, cache);
|
||||
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \class IMDBNode
|
||||
/// \brief A Dataset derived class to represent IMDB dataset
|
||||
class IMDBNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] std::string dataset_dir - dir directory of IMDB dataset.
|
||||
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells Folder what to read.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~IMDBNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kIMDBNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] ostream out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[out] DatasetOp *node_ops A vector containing shared pointer to the Dataset Ops that this object will
|
||||
/// create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \param[out] int32_t *shard_id The shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] DatasetSizeGetter size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] bool estimate This is only supported by some of the ops and it's used to speed up the process of
|
||||
/// getting dataset size at the expense of accuracy.
|
||||
/// \param[out] int64_t *dataset_size The size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] json *out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
/// \param[in] sampler Tells IMDBOp what to read.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to read dataset in json
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMDB_NODE_H_
|
|
@ -2727,6 +2727,94 @@ inline std::shared_ptr<ImageFolderDataset> MS_API ImageFolder(const std::string
|
|||
MapStringToChar(class_indexing), cache);
|
||||
}
|
||||
|
||||
/// \class IMDBDataset
|
||||
/// \brief A source dataset for reading and parsing IMDB dataset.
|
||||
class MS_API IMDBDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of IMDBDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of IMDBDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of IMDBDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of IMDBDataset.
|
||||
~IMDBDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief A source dataset for reading and parsing IMDB dataset.
|
||||
/// \note The generated dataset has two columns ["text", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all"
|
||||
/// (Default="all").
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the IMDBDataset.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define dataset path and MindData object */
|
||||
/// std::string dataset_path = "/path/to/imdb_dataset_directory";
|
||||
/// std::shared_ptr<Dataset> ds = IMDB(dataset_path, "all");
|
||||
///
|
||||
/// /* Create iterator to read dataset */
|
||||
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In IMDB dataset, each data dictionary has keys "text" and "label" */
|
||||
/// auto text = row["text"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<IMDBDataset> MS_API
|
||||
IMDB(const std::string &dataset_dir, const std::string &usage = "all",
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief A source dataset for reading and parsing IMDB dataset.
|
||||
/// \note The generated dataset has two columns ["text", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the IMDBDataset.
|
||||
inline std::shared_ptr<IMDBDataset> MS_API IMDB(const std::string &dataset_dir, const std::string &usage,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief A source dataset for reading and parsing IMDB dataset.
|
||||
/// \note The generated dataset has two columns ["text", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the IMDBDataset.
|
||||
inline std::shared_ptr<IMDBDataset> MS_API IMDB(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<IMDBDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class IWSLT2016Dataset.
|
||||
/// \brief A source dataset for reading and parsing IWSLT2016 dataset.
|
||||
class MS_API IWSLT2016Dataset : public Dataset {
|
||||
|
@ -4890,7 +4978,7 @@ class MS_API WikiTextDataset : public Dataset {
|
|||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In WikiText dataset, each dictionary has key "text" */
|
||||
/// auto text = row["image"];
|
||||
/// auto text = row["text"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<WikiTextDataset> MS_API WikiText(const std::string &dataset_dir,
|
||||
const std::string &usage = "all", int64_t num_samples = 0,
|
||||
|
|
|
@ -47,6 +47,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class FashionMnistDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class IMDBDataset;
|
||||
friend class KMnistDataset;
|
||||
friend class LJSpeechDataset;
|
||||
friend class ManifestDataset;
|
||||
|
|
|
@ -76,7 +76,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
|
||||
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \
|
||||
check_conll2000_dataset, check_amazon_review_dataset, check_semeion_dataset, check_caltech101_dataset, \
|
||||
check_caltech256_dataset, check_wiki_text_dataset
|
||||
check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -3670,6 +3670,147 @@ class ImageFolderDataset(MappableDataset):
|
|||
return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing)
|
||||
|
||||
|
||||
class IMDBDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Internet Movie Database (IMDb).
|
||||
|
||||
The generated dataset has two columns: :py:obj:`[text, label]`.
|
||||
The tensor of column :py:obj:`text` is of the string type.
|
||||
The tensor of column :py:obj:`label` is of a scalar of uint32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be `train`, `test` or `all`
|
||||
(default=None, will read all samples).
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all samples).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into (default=None). When this argument is specified, `num_samples` reflects
|
||||
the maximum sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- The shape of the test column.
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> imdb_dataset_dir = "/path/to/imdb_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) Read all samples (text files) in imdb_dataset_dir with 8 threads
|
||||
>>> dataset = ds.IMDBDataset(dataset_dir=imdb_dataset_dir, num_parallel_workers=8)
|
||||
>>>
|
||||
>>> # 2) Read train samples (text files).
|
||||
>>> dataset = ds.IMDBDataset(dataset_dir=imdb_dataset_dir, usage="train")
|
||||
|
||||
About IMDBDataset:
|
||||
|
||||
The IMDB dataset contains 50, 000 highly polarized reviews from the Internet Movie Database (IMDB). The data set
|
||||
was divided into 25 000 comments for training and 25 000 comments for testing, with both the training set and test
|
||||
set containing 50% positive and 50% negative comments. Train labels and test labels are all lists of 0 and 1, where
|
||||
0 stands for negative and 1 for positive.
|
||||
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── imdb_dataset_directory
|
||||
├── train
|
||||
│ ├── pos
|
||||
│ │ ├── 0_9.txt
|
||||
│ │ ├── 1_7.txt
|
||||
│ │ ├── ...
|
||||
│ ├── neg
|
||||
│ │ ├── 0_3.txt
|
||||
│ │ ├── 1_1.txt
|
||||
│ │ ├── ...
|
||||
├── test
|
||||
│ ├── pos
|
||||
│ │ ├── 0_10.txt
|
||||
│ │ ├── 1_10.txt
|
||||
│ │ ├── ...
|
||||
│ ├── neg
|
||||
│ │ ├── 0_2.txt
|
||||
│ │ ├── 1_3.txt
|
||||
│ │ ├── ...
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@InProceedings{maas-EtAl:2011:ACL-HLT2011,
|
||||
author = {Maas, Andrew L. and Daly, Raymond E. and Pham, Peter T. and Huang, Dan
|
||||
and Ng, Andrew Y. and Potts, Christopher},
|
||||
title = {Learning Word Vectors for Sentiment Analysis},
|
||||
booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics:
|
||||
Human Language Technologies},
|
||||
month = {June},
|
||||
year = {2011},
|
||||
address = {Portland, Oregon, USA},
|
||||
publisher = {Association for Computational Linguistics},
|
||||
pages = {142--150},
|
||||
url = {http://www.aclweb.org/anthology/P11-1015}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_imdb_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
|
||||
num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.IMDBNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class IWSLT2016Dataset(SourceDataset, TextBaseDataset):
|
||||
"""
|
||||
A source dataset that reads and parses IWSLT2016 datasets.
|
||||
|
|
|
@ -63,6 +63,35 @@ def check_imagefolderdataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_imdb_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original IMDBDataset."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test", "all"], "usage")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_iwslt2016_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016dataset)."""
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_fake_image_test.cc
|
||||
c_api_dataset_fashion_mnist_test.cc
|
||||
c_api_dataset_flickr_test.cc
|
||||
c_api_dataset_imdb_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_iwslt_test.cc
|
||||
c_api_dataset_kmnist_test.cc
|
||||
|
|
|
@ -0,0 +1,260 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data and get all data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBBasic.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "all"; // 'train', 'test', 'all'
|
||||
|
||||
// Create a IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto text = row["text"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 8);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data and get train data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBTrain) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBTrain.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "train"; // 'train', 'test', 'all'
|
||||
|
||||
// Create a IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto text = row["text"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 4);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data and get test data.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBTest) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBTest.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "test"; // 'train', 'test', 'all'
|
||||
|
||||
// Create a IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto text = row["text"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 4);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data and test pipeline.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBBasicWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBBasicWithPipeline.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "all"; // 'train', 'test', 'all'
|
||||
|
||||
// Create two IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, usage);
|
||||
std::shared_ptr<Dataset> ds2 = IMDB(dataset_path, usage);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 3;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto text = row["text"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text.Shape() << ", Tensor label shape: " << label.Shape() << "\n";
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 40);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data with GetDatasetSize, GetColumnNames, GetBatchSize.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBGetters.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "all"; // 'train', 'test', 'all'
|
||||
|
||||
// Create a IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, usage);
|
||||
std::vector<std::string> column_names = {"text", "label"};
|
||||
|
||||
std::vector<DataType> types = ToDETypes(ds1->GetOutputTypes());
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "string");
|
||||
EXPECT_EQ(types[1].ToString(), "int32");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_EQ(ds1->GetDatasetSize(), 8);
|
||||
EXPECT_EQ(ds1->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds1->GetBatchSize(), 1);
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data with errors.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "all"; // 'train', 'test', 'all'
|
||||
|
||||
// Create a IMDB Dataset with non-existing dataset dir
|
||||
std::shared_ptr<Dataset> ds0 = IMDB("NotExistDir", usage);
|
||||
EXPECT_NE(ds0, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
|
||||
// Expect failure: invalid IMDB input
|
||||
EXPECT_EQ(iter0, nullptr);
|
||||
|
||||
// Create a IMDB Dataset with err usage
|
||||
std::shared_ptr<Dataset> ds1 = IMDB(dataset_path, "val");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid IMDB input
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Test IMDB Dataset.
|
||||
/// Description: read IMDB data with Null SamplerError.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestIMDBWithNullSamplerError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIMDBWithNullSamplerError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testIMDBDataset";
|
||||
std::string usage = "all";
|
||||
|
||||
// Create a IMDB Dataset
|
||||
std::shared_ptr<Dataset> ds = IMDB(dataset_path, usage, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid IMDB input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
test_neg_0.txt
|
|
@ -0,0 +1 @@
|
|||
test_neg_1.txt
|
|
@ -0,0 +1 @@
|
|||
test_pos_0.txt
|
|
@ -0,0 +1 @@
|
|||
test_pos_1.txt
|
|
@ -0,0 +1 @@
|
|||
train_neg_0.txt
|
|
@ -0,0 +1 @@
|
|||
train_neg_1.txt
|
|
@ -0,0 +1 @@
|
|||
train_pos_0.txt
|
|
@ -0,0 +1 @@
|
|||
train_pos_1.txt
|
|
@ -0,0 +1 @@
|
|||
train_unsup_0.txt
|
|
@ -0,0 +1,732 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testIMDBDataset"
|
||||
|
||||
|
||||
def test_imdb_basic():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case basic")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 8
|
||||
|
||||
content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt",
|
||||
"test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
|
||||
label = [1, 1, 0, 0, 1, 1, 0, 0]
|
||||
|
||||
num_iter = 0
|
||||
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
# each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
strs = item["text"].item().decode("utf8")
|
||||
logger.info("text is {}".format(strs))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
assert strs == content[index]
|
||||
assert label[index] == int(item["label"])
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_imdb_test():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from test file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case test")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
usage = "test"
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 4
|
||||
|
||||
content = ["test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
|
||||
label = [1, 1, 0, 0]
|
||||
|
||||
num_iter = 0
|
||||
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
# each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
strs = item["text"].item().decode("utf8")
|
||||
logger.info("text is {}".format(strs))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
assert strs == content[index]
|
||||
assert label[index] == int(item["label"])
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_imdb_train():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from train file.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case train")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
usage = "train"
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 4
|
||||
|
||||
content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt"]
|
||||
label = [1, 1, 0, 0]
|
||||
|
||||
num_iter = 0
|
||||
for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
# each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
strs = item["text"].item().decode("utf8")
|
||||
logger.info("text is {}".format(strs))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
assert strs == content[index]
|
||||
assert label[index] == int(item["label"])
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_imdb_num_samples():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with num_samples=10 and num_parallel_workers=2.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case numSamples")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_samples=6, num_parallel_workers=2)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 6
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 6
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
|
||||
def test_imdb_num_shards():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with num_shards=2 and shard_id=1.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case numShards")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=1)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 4
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_imdb_shard_id():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with num_shards=4 and shard_id=1.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case withShardID")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=0)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 4
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_imdb_no_shuffle():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with shuffle=False.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case noShuffle")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 8
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_imdb_true_shuffle():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with shuffle=True.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case extraShuffle")
|
||||
# define parameters
|
||||
repeat_count = 2
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, shuffle=True)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 16
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 16
|
||||
|
||||
|
||||
def test_random_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with sampler=ds.RandomSampler().
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case RandomSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.RandomSampler()
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 8
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_distributed_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with sampler=ds.DistributedSampler().
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case DistributedSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.DistributedSampler(4, 1)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 2
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_pk_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with sampler=ds.PKSampler().
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case PKSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.PKSampler(3)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 6
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_subset_random_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with sampler=ds.SubsetRandomSampler().
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case SubsetRandomSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
indices = [0, 3, 1, 2, 5, 4]
|
||||
sampler = ds.SubsetRandomSampler(indices)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 6
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_weighted_random_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with sampler=ds.WeightedRandomSampler().
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case WeightedRandomSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
|
||||
sampler = ds.WeightedRandomSampler(weights, 6)
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 6
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_weighted_random_sampler_exception():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with random sampler exception.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test error cases for WeightedRandomSampler")
|
||||
error_msg_1 = "type of weights element must be number"
|
||||
with pytest.raises(TypeError, match=error_msg_1):
|
||||
weights = ""
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_2 = "type of weights element must be number"
|
||||
with pytest.raises(TypeError, match=error_msg_2):
|
||||
weights = (0.9, 0.8, 1.1)
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
weights = []
|
||||
sampler = ds.WeightedRandomSampler(weights)
|
||||
sampler.parse()
|
||||
|
||||
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
|
||||
sampler = ds.WeightedRandomSampler(weights)
|
||||
sampler.parse()
|
||||
|
||||
error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
|
||||
with pytest.raises(RuntimeError, match=error_msg_5):
|
||||
weights = [0, 0, 0, 0, 0]
|
||||
sampler = ds.WeightedRandomSampler(weights)
|
||||
sampler.parse()
|
||||
|
||||
|
||||
def test_chained_sampler_with_random_sequential_repeat():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with Random and Sequential, with repeat.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
|
||||
|
||||
# Create chained sampler, random and sequential
|
||||
sampler = ds.RandomSampler()
|
||||
child_sampler = ds.SequentialSampler()
|
||||
sampler.add_child(child_sampler)
|
||||
# Create IMDBDataset with sampler
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
|
||||
data1 = data1.repeat(count=3)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 24
|
||||
|
||||
# Verify number of iterations
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 24
|
||||
|
||||
|
||||
def test_chained_sampler_with_distribute_random_batch_then_repeat():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with Distributed and Random, with batch then repeat.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")
|
||||
|
||||
# Create chained sampler, distributed and random
|
||||
sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
|
||||
child_sampler = ds.RandomSampler()
|
||||
sampler.add_child(child_sampler)
|
||||
# Create IMDBDataset with sampler
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
|
||||
data1 = data1.batch(batch_size=5, drop_remainder=True)
|
||||
data1 = data1.repeat(count=3)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 0
|
||||
|
||||
# Verify number of iterations
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
# Note: Each of the 4 shards has 44/4=11 samples
|
||||
# Note: Number of iterations is (11/5 = 2) * 3 = 6
|
||||
assert num_iter == 0
|
||||
|
||||
|
||||
def test_chained_sampler_with_weighted_random_pk_sampler():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with WeightedRandom and PKSampler.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")
|
||||
|
||||
# Create chained sampler, WeightedRandom and PKSampler
|
||||
weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
|
||||
sampler = ds.WeightedRandomSampler(weights=weights, num_samples=6)
|
||||
child_sampler = ds.PKSampler(num_val=3) # Number of elements per class is 3 (and there are 4 classes)
|
||||
sampler.add_child(child_sampler)
|
||||
# Create IMDBDataset with sampler
|
||||
data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 6
|
||||
|
||||
# Verify number of iterations
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
# Note: WeightedRandomSampler produces 12 samples
|
||||
# Note: Child PKSampler produces 12 samples
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_imdb_rename():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with rename.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case rename")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_samples=8)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
data1 = data1.rename(input_columns=["text"], output_columns="text2")
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text2"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_imdb_zip():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with zip.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test Case zip")
|
||||
# define parameters
|
||||
repeat_count = 2
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.IMDBDataset(DATA_DIR, num_samples=4)
|
||||
data2 = ds.IMDBDataset(DATA_DIR, num_samples=4)
|
||||
|
||||
data1 = data1.repeat(repeat_count)
|
||||
# rename dataset2 for no conflict
|
||||
data2 = data2.rename(input_columns=["text", "label"], output_columns=["text1", "label1"])
|
||||
data3 = ds.zip((data1, data2))
|
||||
|
||||
num_iter = 0
|
||||
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# in this example, each dictionary has keys "text" and "label"
|
||||
logger.info("text is {}".format(item["text"].item().decode("utf8")))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_imdb_exception():
|
||||
"""
|
||||
Feature: Test IMDB Dataset.
|
||||
Description: read data from all file with exception.
|
||||
Expectation: the data is processed successfully.
|
||||
"""
|
||||
logger.info("Test imdb exception")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
def exception_func2(text, label):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
try:
|
||||
data = ds.IMDBDataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.IMDBDataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func2, input_columns=["text", "label"],
|
||||
output_columns=["text", "label", "label1"],
|
||||
column_order=["text", "label", "label1"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
data_dir_invalid = "../data/dataset/IMDBDATASET"
|
||||
try:
|
||||
data = ds.IMDBDataset(data_dir_invalid)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except ValueError as e:
|
||||
assert "does not exist or is not a directory or permission denied" in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_imdb_basic()
|
||||
test_imdb_test()
|
||||
test_imdb_train()
|
||||
test_imdb_num_samples()
|
||||
test_random_sampler()
|
||||
test_distributed_sampler()
|
||||
test_pk_sampler()
|
||||
test_subset_random_sampler()
|
||||
test_weighted_random_sampler()
|
||||
test_weighted_random_sampler_exception()
|
||||
test_chained_sampler_with_random_sequential_repeat()
|
||||
test_chained_sampler_with_distribute_random_batch_then_repeat()
|
||||
test_chained_sampler_with_weighted_random_pk_sampler()
|
||||
test_imdb_num_shards()
|
||||
test_imdb_shard_id()
|
||||
test_imdb_no_shuffle()
|
||||
test_imdb_true_shuffle()
|
||||
test_imdb_rename()
|
||||
test_imdb_zip()
|
||||
test_imdb_exception()
|
Loading…
Reference in New Issue