[assistant][ops]New operator implementation, include RenderedSST2Dataset

This commit is contained in:
uccInf 2022-09-01 20:20:28 +08:00
parent c05a7d37bd
commit 543d82670f
32 changed files with 1683 additions and 6 deletions

View File

@ -0,0 +1,118 @@
mindspore.dataset.RenderedSST2Dataset
=====================================
.. py:class:: mindspore.dataset.RenderedSST2Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
读取和解析RenderedSST2数据集的源文件构建数据集。
生成的数据集有两列 `[image, label]``image` 列的数据类型为uint8。`label` 列的数据类型为uint32。
参数:
- **dataset_dir** (str) - 包含数据集文件的根目录路径。
- **usage** (str, 可选) - 指定数据集的子集,可取值为'train'、'val'、'test'或'all'。默认值None读取全部样本图片。
- **num_samples** (int, 可选) - 指定从数据集中读取的样本数可以小于数据集总数。默认值None读取全部样本图片。
- **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值None使用mindspore.dataset.config中配置的线程数。
- **shuffle** (bool, 可选) - 是否混洗数据集。默认值None下表中会展示不同参数配置的预期行为。
- **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值False不解码。
- **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值None下表中会展示不同配置的预期行为。
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值None。指定此参数后 `num_samples` 表示每个分片的最大样本数。
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值None。只有当指定了 `num_shards` 时才能指定此参数。
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值None不使用缓存。
异常:
- **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。
- **RuntimeError** - 同时指定了 `sampler``shuffle` 参数。
- **RuntimeError** - 同时指定了 `sampler``num_shards` 参数或同时指定了 `sampler``shard_id` 参数。
- **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
- **ValueError** - `usage` 参数取值不为'train'、'val'、'test'或'all'。
- **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。
- **ValueError** - `shard_id` 参数值错误小于0或者大于等于 `num_shards`
.. note:: 此数据集可以指定参数 `sampler` ,但参数 `sampler` 和参数 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数组合及预期的行为。
.. list-table:: 配置 `sampler``shuffle` 的不同组合得到的预期排序结果
:widths: 25 25 50
:header-rows: 1
* - 参数 `sampler`
- 参数 `shuffle`
- 预期数据顺序
* - None
- None
- 随机排列
* - None
- True
- 随机排列
* - None
- False
- 顺序排列
* - `sampler` 实例
- None
- 由 `sampler` 行为定义的顺序
* - `sampler` 实例
- True
- 不允许
* - `sampler` 实例
- False
- 不允许
**关于RenderedSST2数据集**
Rendered SST2是一个图像分类数据集它是由SST2数据集中的数据生成的。数据集被分割成三份每一份包含有两类positive和negative
在train这一份下共有6920张图像3610张positive3310张negative在validation这一份下共有872张图像444张positive428张negative
在test这一份下共有1821张图像909张positive912张negative
以下为原始RenderedSST2数据集的结构您可以将数据集文件解压得到如下的文件结构并通过MindSpore的API进行读取。
.. code-block::
.
└── rendered_sst2_dataset_directory
├── train
│ ├── negative
│ │ ├── 0001.jpg
│ │ ├── 0002.jpg
│ │ ...
│ └── positive
│ ├── 0001.jpg
│ ├── 0002.jpg
│ ...
├── test
│ ├── negative
│ │ ├── 0001.jpg
│ │ ├── 0002.jpg
│ │ ...
│ └── positive
│ ├── 0001.jpg
│ ├── 0002.jpg
│ ...
└── valid
├── negative
│ ├── 0001.jpg
│ ├── 0002.jpg
│ ...
└── positive
├── 0001.jpg
├── 0002.jpg
...
**引用:**
.. code-block::
@inproceedings{socher-etal-2013-recursive,
title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
Christopher D. and Ng, Andrew and Potts, Christopher},
booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
month = oct,
year = {2013},
address = {Seattle, Washington, USA},
publisher = {Association for Computational Linguistics},
url = {https://www.aclweb.org/anthology/D13-1170},
pages = {1631--1642},
}
.. include:: mindspore.dataset.api_list_vision.rst

View File

@ -120,6 +120,7 @@ mindspore.dataset
mindspore.dataset.PhotoTourDataset
mindspore.dataset.Places365Dataset
mindspore.dataset.QMnistDataset
mindspore.dataset.RenderedSST2Dataset
mindspore.dataset.SBDataset
mindspore.dataset.SBUDataset
mindspore.dataset.SemeionDataset

View File

@ -32,6 +32,7 @@ Vision
mindspore.dataset.PhotoTourDataset
mindspore.dataset.Places365Dataset
mindspore.dataset.QMnistDataset
mindspore.dataset.RenderedSST2Dataset
mindspore.dataset.SBDataset
mindspore.dataset.SBUDataset
mindspore.dataset.SemeionDataset

View File

@ -116,6 +116,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h"
@ -2005,6 +2006,33 @@ RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vecto
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds =
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds =
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
bool decode, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds =
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;

View File

@ -56,6 +56,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
@ -642,6 +643,18 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(RenderedSST2Node, 2, ([](const py::module *m) {
(void)py::class_<RenderedSST2Node, DatasetNode, std::shared_ptr<RenderedSST2Node>>(
*m, "RenderedSST2Node", "to create a RenderedSST2Node")
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
const py::handle &sampler) {
auto rendered_sst2 =
std::make_shared<RenderedSST2Node>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(rendered_sst2->ValidateParams());
return rendered_sst2;
}));
}));
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
"to create an SBUNode")

View File

@ -44,6 +44,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
places365_op.cc
qmnist_op.cc
random_data_op.cc
rendered_sst2_op.cc
sbu_op.cc
semeion_op.cc
sogou_news_op.cc

View File

@ -0,0 +1,373 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/rendered_sst2_op.h"
#include <fstream>
#include <unordered_set>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
RenderedSST2Op::RenderedSST2Op(int32_t num_wkrs, const std::string &file_dir, const std::string &usage,
int32_t queue_size, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, uint32_t> &map, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_wkrs, queue_size, std::move(sampler)),
folder_path_(file_dir),
usage_(usage),
decode_(do_decode),
extensions_(exts),
class_index_(map),
data_schema_(std::move(data_schema)),
sampler_ind_(0) {
folder_path_queue_ = std::make_unique<Queue<std::string>>(num_wkrs * queue_size);
folder_classId_queue_ = std::make_unique<Queue<uint32_t>>(num_wkrs * queue_size);
image_name_queue_ = std::make_unique<Queue<FolderImagesPair>>(num_wkrs * queue_size);
}
// Master thread that pulls the prescan worker's results.
// Keep collecting results until all prescan workers quit
// Then consolidate 2 level shuffles together into 1 giant vector
// calculate numRows then return
Status RenderedSST2Op::PrepareData() {
std::vector<FolderImagesPair> v;
int64_t cnt = 0;
while (cnt != num_workers_) { // count number of end signals
FolderImagesPair p;
RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p));
if (p == nullptr) {
cnt++;
} else {
v.push_back(p);
}
}
std::sort(v.begin(), v.end(),
[](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; });
// following loop puts the 2 level of shuffles together into 1 vector
for (size_t ind = 0; ind < v.size(); ++ind) {
while (!v[ind]->second.empty()) {
image_label_pairs_.push_back(v[ind]->second.front());
image_prefix_.push_back(v[ind]->first);
v[ind]->second.pop();
}
}
image_label_pairs_.shrink_to_fit();
num_rows_ = image_label_pairs_.size();
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
DatasetName() + " file path: " + folder_path_);
}
// free memory of two queues used for pre-scan
folder_path_queue_->Reset();
folder_classId_queue_->Reset();
image_name_queue_->Reset();
return Status::OK();
}
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow
Status RenderedSST2Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
RETURN_UNEXPECTED_IF_NULL(trow);
ImageLabelPair pair_ptr = image_label_pairs_[row_id];
std::shared_ptr<Tensor> image, label;
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(image_prefix_[row_id] + pair_ptr->first, &image));
if (decode_) {
Status rc = Decode(image, &image);
if (rc.IsError()) {
std::string err = "Invalid image, " + folder_path_ + (pair_ptr->first) +
" decode failed, the image is broken or permission denied.";
RETURN_STATUS_UNEXPECTED(err);
}
}
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
return Status::OK();
}
void RenderedSST2Op::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows: " << num_rows_ << "\n"
<< DatasetName(true) << " directory: " << folder_path_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
}
}
// Derived from RandomAccessOp
Status RenderedSST2Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
if (image_label_pairs_.empty()) {
RETURN_STATUS_UNEXPECTED("Invalid dataset_dir, " + DatasetName(true) +
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
DatasetName() + " file path: " + folder_path_);
} else {
RETURN_STATUS_UNEXPECTED(
"[Internal ERROR], Map containing image-index pair is nullptr or has been set in other place,"
"it must be empty before using GetClassIds.");
}
}
for (size_t i = 0; i < image_label_pairs_.size(); ++i) {
(*cls_ids)[image_label_pairs_[i]->second].push_back(i);
}
for (auto &pair : (*cls_ids)) {
pair.second.shrink_to_fit();
}
return Status::OK();
}
// Worker Entry for pre-scanning all the folders and do the 1st level shuffle
// Worker pull a file path from folder_path_queue_ (which is a Queue), walks all the images under that folderpath
// After walking is complete, sort all the file names (relative path to all png files under the same directory )
// (Sort is automatically conducted using a set which is implemented using a Red-Black Tree)
// Add the sorted filenames in to a queue. The make a pair (folderpath, queue<filenames>*),
// folderpath is used for 2nd level sorting.
// FYI: 1st level sorting: sort all images under the same directory.
// FYI: 2nd level sorting: sort all folder names
// push this pair to image_name_queue (which is again a Queue)
Status RenderedSST2Op::PrescanWorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::string folder_;
uint32_t current_class_id;
RETURN_IF_NOT_OK(folder_path_queue_->PopFront(&folder_));
RETURN_IF_NOT_OK(folder_classId_queue_->PopFront(&current_class_id));
while (!folder_.empty()) {
Path folder(folder_);
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
if (!folder.Exists() || dirItr == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid dataset_dir, " + folder_ + " does not exist or permission denied.");
}
auto offset_ = folder_.size();
std::set<std::string> imgs; // use this for ordering
while (dirItr->HasNext()) {
Path file = dirItr->Next();
if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) {
(void)imgs.insert(file.ToString().substr(offset_));
} else {
MS_LOG(WARNING) << DatasetName(true) << " operator unsupported file found: " << file.ToString()
<< ", extension: " << file.Extension() << ".";
}
}
FolderImagesPair p = std::make_shared<std::pair<std::string, std::queue<ImageLabelPair>>>();
p->first = folder_;
for (const std::string &img : imgs) {
p->second.push(std::make_shared<std::pair<std::string, uint32_t>>(img, current_class_id));
}
RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p));
RETURN_IF_NOT_OK(folder_path_queue_->PopFront(&folder_));
RETURN_IF_NOT_OK(folder_classId_queue_->PopFront(&current_class_id));
}
RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr)); // end signal
return Status::OK();
}
// This helper function walks all folder_paths, and send each folderpath to folder_path_queue_
Status RenderedSST2Op::WalkFolder(Path *dir) {
RETURN_UNEXPECTED_IF_NULL(dir);
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
RETURN_UNEXPECTED_IF_NULL(dir_itr);
auto offset_ = dir->ToString().size();
std::string current_class;
while (dir_itr->HasNext()) {
Path subdir = dir_itr->Next();
if (subdir.IsDirectory()) {
RETURN_IF_NOT_OK(folder_path_queue_->EmplaceBack(subdir.ToString()));
current_class = subdir.ToString().substr(offset_ + 1);
if (class_index_.find(current_class) == class_index_.end()) {
class_index_[current_class] = class_index_.size();
}
RETURN_IF_NOT_OK(folder_classId_queue_->EmplaceBack(class_index_[current_class]));
}
}
return Status::OK();
}
// A thread that calls WalkFolder
Status RenderedSST2Op::StartAsyncWalk() {
TaskManager::FindMe()->Post();
Path dir(folder_path_);
if (!dir.Exists() || !dir.IsDirectory()) {
RETURN_STATUS_UNEXPECTED("Invalid path, " + folder_path_ + " may not exist or is not a directory.");
}
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
RETURN_UNEXPECTED_IF_NULL(dir_itr);
auto offset_ = folder_path_.length();
while (dir_itr->HasNext()) {
Path subdir = dir_itr->Next();
if (subdir.IsDirectory()) {
std::string name = subdir.ToString().substr(offset_ + 1);
if (usage_ == name) {
RETURN_IF_NOT_OK(WalkFolder(&subdir));
} else if (usage_ == "val" && name == "valid") {
RETURN_IF_NOT_OK(WalkFolder(&subdir));
} else if (usage_ == "all" && (name == "train" || name == "test" || name == "valid")) {
RETURN_IF_NOT_OK(WalkFolder(&subdir));
}
}
}
// send out num_workers_ end signal to folder_path_queue_, 1 for each worker.
// Upon receiving end Signal, worker quits and set another end Signal to image_name_queue.
for (int32_t ind = 0; ind < num_workers_; ++ind) {
RETURN_IF_NOT_OK(folder_path_queue_->EmplaceBack("")); // end signal
RETURN_IF_NOT_OK(folder_classId_queue_->EmplaceBack(0));
}
return Status::OK();
}
Status RenderedSST2Op::RegisterAndLaunchThreads() {
RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
RETURN_IF_NOT_OK(folder_path_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(folder_classId_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks()));
// The following code launch 3 threads group
// 1) A thread that walks all folders and push the folder names to a util:Queue folder_path_queue_.
// 2) Workers that pull foldername from folder_path_queue_, walk it and return the sorted images to image_name_queue
// 3) Launch main workers that load TensorRows by reading all images
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(Name() + "::WalkDir",
std::bind(&RenderedSST2Op::StartAsyncWalk, this), nullptr, id()));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
std::bind(&RenderedSST2Op::PrescanWorkerEntry, this, std::placeholders::_1),
Name() + "::PrescanWorkerEntry", id()));
return Status::OK();
}
Status RenderedSST2Op::WalkFolderForCountRows(Path *dir, std::queue<std::string> *folder_paths,
std::map<std::string, uint32_t> *class_index) {
RETURN_UNEXPECTED_IF_NULL(dir);
RETURN_UNEXPECTED_IF_NULL(folder_paths);
RETURN_UNEXPECTED_IF_NULL(class_index);
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
RETURN_UNEXPECTED_IF_NULL(dir_itr);
std::string current_class;
auto offset_ = dir->ToString().size();
while (dir_itr->HasNext()) {
Path subdir = dir_itr->Next();
if (subdir.IsDirectory()) {
folder_paths->push(subdir.ToString());
current_class = subdir.ToString().substr(offset_ + 1);
if (class_index->find(current_class) == class_index->end()) {
(*class_index)[current_class] = class_index->size();
}
}
}
return Status::OK();
}
Status RenderedSST2Op::CountRows(std::queue<std::string> *folder_paths, int64_t *num_rows,
const std::set<std::string> &exts) {
RETURN_UNEXPECTED_IF_NULL(folder_paths);
RETURN_UNEXPECTED_IF_NULL(num_rows);
int64_t row_cnt = 0;
while (!folder_paths->empty()) {
Path subdir(folder_paths->front());
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&subdir);
if (!subdir.Exists() || dir_itr == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid subdirectory, RenderedSST2 Dataset subdirectory: " + subdir.ToString() +
" does not exist or permission denied");
}
while (dir_itr->HasNext()) {
if (exts.empty() || exts.find(dir_itr->Next().Extension()) != exts.end()) {
++row_cnt;
}
}
folder_paths->pop();
}
(*num_rows) = row_cnt;
return Status::OK();
}
Status RenderedSST2Op::CountRowsAndClasses(const std::string &path, const std::string &usage,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes) {
Path dir(path);
std::string err_msg = "";
err_msg += (!dir.Exists() || !dir.IsDirectory())
? "Invalid dataset_dir, " + path + " does not exist or the path is not a directory. "
: "";
err_msg += (num_classes == nullptr && num_rows == nullptr) ? "[Internal ERROR] num_class and num_rows are null." : "";
if (!err_msg.empty()) {
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::queue<std::string> folder_paths;
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
std::map<std::string, uint32_t> class_index;
auto offset_ = path.size();
RETURN_UNEXPECTED_IF_NULL(dir_itr);
while (dir_itr->HasNext()) {
Path subdir = dir_itr->Next();
if (subdir.IsDirectory()) {
std::string name = subdir.ToString().substr(offset_ + 1);
name = name == "valid" ? "val" : name;
if (usage == name) {
RETURN_IF_NOT_OK(WalkFolderForCountRows(&subdir, &folder_paths, &class_index));
} else if (usage == "all" && (name == "train" || name == "test" || name == "val")) {
RETURN_IF_NOT_OK(WalkFolderForCountRows(&subdir, &folder_paths, &class_index));
}
}
}
if (num_classes != nullptr) {
*num_classes = class_index.size();
}
// return here if only num_class is needed
RETURN_OK_IF_TRUE(num_rows == nullptr);
RETURN_IF_NOT_OK(CountRows(&folder_paths, num_rows, exts));
return Status::OK();
}
Status RenderedSST2Op::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
// Get number of classes
Status RenderedSST2Op::GetNumClasses(int64_t *num_classes) {
RETURN_UNEXPECTED_IF_NULL(num_classes);
if (num_classes_ > 0) {
*num_classes = num_classes_;
return Status::OK();
}
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, usage_, extensions_, nullptr, num_classes));
num_classes_ = *num_classes;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,172 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// Forward declares
template <typename T>
class Queue;
using ImageLabelPair = std::shared_ptr<std::pair<std::string, uint32_t>>;
using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>;
class RenderedSST2Op : public MappableLeafOp {
public:
/// Constructor.
/// @param int32_t num_wkrs - Num of workers reading images in parallel.
/// @param const std::string &file_dir - Directory of RenderedSST2Dataset.
/// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'val' or 'all'.
/// @param int32_t queue_size - Connector queue size.
/// @param bool do_decode - Decode the images after reading.
/// @param std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the dir.
/// @param std::map<std::string, int32_t> &map- Map of class name and class id.
/// @param std::unique_ptr<dataschema> data_schema - Schema of data.
/// @param std::shared_ptr<SamplerRT> sampler - Sampler tells RenderedSST2Op what to read.
RenderedSST2Op(int32_t num_wkrs, const std::string &file_dir, const std::string &usage, int32_t queue_size,
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, uint32_t> &map,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// Destructor.
~RenderedSST2Op() override = default;
/// Initialize RenderedSST2Op related var, calls the function to walk all files.
/// @return Status The status code returned.
Status PrepareData() override;
/// Worker thread pulls a number of IOBlock from IOBlock Queue, make a TensorRow and push it to Connector.
/// @param int32_t worker_id - Id of each worker.
/// @return Status The status code returned.
Status PrescanWorkerEntry(int32_t worker_id);
/// Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
/// @param (std::map<int32_t, std::vector<int64_t >> * cls_ids - Key label, val all ids for this class.
/// @return Status The status code returned.
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
/// A print method typically used for debugging.
/// @param std::ostream &out - Out stream.
/// @param bool show_all - Whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
/// This function is a hack! It is to return the num_class and num_rows. The result
/// returned by this function may not be consistent with what RenderedSST2Op is going to return
/// user this at your own risk!
/// @param const std::string &path - Directory of RenderedSST2Dataset.
/// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'valid' or 'all'.
/// @param const std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the
/// dir.
/// @param int64_t *num_rows - The number of rows.
/// @param int64_t *num_classes - The number of classes.
/// @return Status of the function.
static Status CountRowsAndClasses(const std::string &path, const std::string &usage,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes);
/// This help function is used to count the num_rows.
/// @param std::queue<std::string> *folder_paths - A queue contains all the image folder paths.
/// @param int64_t *num_rows - The number of rows.
/// @param const std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the
/// dir.
/// @return Status of the function.
static Status CountRows(std::queue<std::string> *folder_paths, int64_t *num_rows, const std::set<std::string> &exts);
/// Op name getter.
/// @return Name of the current Op.
std::string Name() const override { return "RenderedSST2Op"; }
/// DatasetName name getter.
/// @return DatasetName of the current Op.
virtual std::string DatasetName(bool upper = false) const { return upper ? "RenderedSST2" : "rendered sst2"; }
/// Base-class override for GetNumClasses.
/// @param num_classes - The number of classes.
/// @return Status of the function.
Status GetNumClasses(int64_t *num_classes) override;
protected:
/// Load a tensor row according to a pair.
/// @param row_id_type row_id - Id for this tensor row.
/// @param TensorRow *row - Image & label read into this tensor row.
/// @return Status The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
/// @param Path * dir - Dir to walk all folders.
/// @return Status The status code returned.
virtual Status WalkFolder(Path *dir);
/// @param Path * dir - Dir to walk all images.
/// @param std::queue<std::string> *folder_paths - A queue contains all the image folder paths.
/// @param std::map<std::string, int32_t> *class_index - A map records the class and the class's Id.
/// @return Status The status code returned.
static Status WalkFolderForCountRows(Path *dir, std::queue<std::string> *folder_paths,
std::map<std::string, uint32_t> *class_index);
/// start walking of all dirs.
/// @return Status The status code returned.
Status StartAsyncWalk();
/// Called first when function is called.
/// @return Status The status code returned.
Status RegisterAndLaunchThreads() override;
/// Private function for computing the assignment of the column name map.
/// @return Status The status code returned.
Status ComputeColMap() override;
std::string folder_path_; // directory of image folder
std::string usage_;
bool recursive_;
bool decode_;
std::set<std::string> extensions_; // extensions allowed
std::map<std::string, uint32_t> class_index_;
std::unique_ptr<DataSchema> data_schema_;
int64_t sampler_ind_;
std::vector<ImageLabelPair> image_label_pairs_;
std::vector<std::string> image_prefix_;
std::unique_ptr<Queue<std::string>> folder_path_queue_;
std::unique_ptr<Queue<uint32_t>> folder_classId_queue_; // the class Id of the images under the folder
std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_

View File

@ -118,6 +118,7 @@ constexpr char kPhotoTourNode[] = "PhotoTourDataset";
constexpr char kPlaces365Node[] = "Places365Dataset";
constexpr char kQMnistNode[] = "QMnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kRenderedSST2Node[] = "RenderedSST2Dataset";
constexpr char kSBUNode[] = "SBUDataset";
constexpr char kSemeionNode[] = "SemeionDataset";
constexpr char kSogouNewsNode[] = "SogouNewsDataset";

View File

@ -45,6 +45,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
places365_node.cc
qmnist_node.cc
random_node.cc
rendered_sst2_node.cc
sbu_node.cc
semeion_node.cc
sogou_news_node.cc

View File

@ -0,0 +1,147 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/rendered_sst2_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const std::set<std::string> kExts = {".png"};
RenderedSST2Node::RenderedSST2Node(const std::string &dataset_dir, const std::string &usage, bool decode,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> RenderedSST2Node::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<RenderedSST2Node>(dataset_dir_, usage_, decode_, sampler, cache_);
return node;
}
void RenderedSST2Node::Print(std::ostream &out) const {
out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
}
Status RenderedSST2Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("RenderedSST2Node", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("RenderedSST2Node", usage_, {"val", "train", "all", "test"}));
RETURN_IF_NOT_OK(ValidateDatasetSampler("RenderedSST2Node", sampler_));
return Status::OK();
}
Status RenderedSST2Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Do internal Schema generation.
// This arg exists in RenderedSST2Op, but is not externalized (in Python API).
RETURN_UNEXPECTED_IF_NULL(node_ops);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
const std::map<std::string, uint32_t> kClassIndex = {};
auto op = std::make_shared<RenderedSST2Op>(num_workers_, dataset_dir_, usage_, connector_que_size_, decode_, kExts,
kClassIndex, std::move(schema), std::move(sampler_rt));
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node.
Status RenderedSST2Node::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size.
Status RenderedSST2Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
RETURN_UNEXPECTED_IF_NULL(dataset_size);
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(RenderedSST2Op::CountRowsAndClasses(dataset_dir_, usage_, kExts, &num_rows, nullptr));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
RETURN_UNEXPECTED_IF_NULL(size_getter);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status RenderedSST2Node::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["decode"] = decode_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
Status RenderedSST2Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kRenderedSST2Node));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kRenderedSST2Node));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kRenderedSST2Node));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kRenderedSST2Node));
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kRenderedSST2Node));
std::string dataset_dir = json_obj["dataset_dir"];
std::string usage = json_obj["usage"];
bool decode = json_obj["decode"];
std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<RenderedSST2Node>(dataset_dir, usage, decode, sampler, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,117 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class RenderedSST2Node.
/// \brief A Dataset derived class to represent RenderedSST2 dataset.
class RenderedSST2Node : public MappableSourceNode {
public:
/// \brief Constructor.
RenderedSST2Node(const std::string &dataset_dir, const std::string &usage, bool decode,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~RenderedSST2Node() override = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kRenderedSST2Node; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
/// \return The dataset_dir
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Getter functions.
/// \return The usage
const std::string &Usage() const { return usage_; }
/// \brief Getter functions.
/// \return The Decode.
bool Decode() const { return decode_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json.
/// \param[in] json_obj The JSON object to be deserialized.
/// \param[out] ds Deserialized dataset.
/// \return Status The status code returned.
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
std::string usage_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
#include <memory>
#include <string>

View File

@ -4882,6 +4882,99 @@ std::shared_ptr<RandomDataDataset> DATASET_API RandomData(const int32_t &total_r
return ds;
}
/// \class RenderedSST2Dataset
/// \brief A source dataset for reading and parsing RenderedSST2 dataset.
class DATASET_API RenderedSST2Dataset : public Dataset {
public:
/// \brief Constructor of RenderedSST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of RenderedSST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of RenderedSST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of RenderedSST2Dataset.
~RenderedSST2Dataset() override = default;
};
/// \brief Function to create a RenderedSST2Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all". Default: "all".
/// \param[in] decode Decode the images after reading. Default: false.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset. Default: RandomSampler().
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the RenderedSST2Dataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string dataset_path = "/path/to/RenderedSST2_dataset_directory";
/// std::shared_ptr<Dataset> ds = RenderedSST2(dataset_path);
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In RenderedSST2 dataset, each data dictionary has keys "image" and "label" */
/// auto image = row["image"];
/// \endcode
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
RenderedSST2(const std::string &dataset_dir, const std::string &usage = "all", bool decode = false,
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
}
/// \brief Function to create a RenderedSST2Dataset
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the RenderedSST2Dataset.
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
RenderedSST2(const std::string &dataset_dir, const std::string &usage, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
}
/// \brief Function to create a RenderedSST2Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the RenderedSST2Dataset.
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
RenderedSST2(const std::string &dataset_dir, const std::string &usage, bool decode,
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
}
/// \class SBUDataset
/// \brief A source dataset that reads and parses SBU dataset.
class DATASET_API SBUDataset : public Dataset {
@ -5268,7 +5361,6 @@ class DATASET_API SST2Dataset : public Dataset {
/// \param[in] shuffle The mode for shuffling data every epoch. Default: ShuffleMode::kGlobal.
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. Default: 1.
/// \param[in] shard_id The shard ID within num_shards. This argument should be

View File

@ -65,6 +65,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
friend class Places365Dataset;
friend class QMnistDataset;
friend class RandomDataDataset;
friend class RenderedSST2Dataset;
friend class SBUDataset;
friend class SemeionDataset;
friend class SpeechCommandsDataset;

View File

@ -63,6 +63,7 @@ __all__ = ["Caltech101Dataset", # Vision
"Places365Dataset", # Vision
"QMnistDataset", # Vision
"RandomDataset", # Vision
"RenderedSST2Dataset", # Vision
"SBDataset", # Vision
"SBUDataset", # Vision
"SemeionDataset", # Vision

View File

@ -37,8 +37,9 @@ from .validators import check_caltech101_dataset, check_caltech256_dataset, chec
check_flickr_dataset, check_flowers102dataset, check_food101_dataset, check_imagefolderdataset, \
check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset, check_sun397_dataset
check_random_dataset, check_rendered_sst2_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, \
check_stl10_dataset, check_sun397_dataset, check_svhn_dataset, check_usps_dataset, check_vocdataset, \
check_wider_face_dataset
from ..core.validator_helpers import replace_none
@ -3870,6 +3871,158 @@ class RandomDataset(SourceDataset, VisionBaseDataset):
return cde.RandomNode(self.total_rows, schema, self.columns_list)
class RenderedSST2Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads and parses RenderedSST2 dataset.
The generated dataset has two columns: :py:obj:`[image, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`label` is of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be 'train', 'val', 'test' or 'all'.
Default: None, will read all samples.
num_samples (int, optional): The number of images to be included in the dataset.
Default: None, will include all images.
num_parallel_workers (int, optional): Number of workers to read the data.
Default: None, set in the config.
shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
Default: None, expected order behavior shown in the table below.
decode (bool, optional): Whether or not to decode the images after reading. Default: False.
sampler (Sampler, optional): Object used to choose samples from the
dataset. Default: None, expected order behavior shown in the table below.
num_shards (int, optional): Number of shards that the dataset will be divided
into. When this argument is specified, `num_samples` reflects
the maximum sample number of per shard. Default: None.
shard_id (int, optional): The shard ID within `num_shards` . This
argument can only be specified when `num_shards` is also specified. Default: None.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
`Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
Default: None, which means no cache is used.
Raises:
RuntimeError: If `dataset_dir` does not contain data files.
ValueError: If `usage` is not 'train', 'test', 'val' or 'all'.
ValueError: If `num_parallel_workers` exceeds the max thread numbers.
RuntimeError: If `sampler` and `shuffle` are specified at the same time.
RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
RuntimeError: If `num_shards` is specified but `shard_id` is None.
RuntimeError: If `shard_id` is specified but `num_shards` is None.
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
Note:
- This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> rendered_sst2_dataset_dir = "/path/to/rendered_sst2_dataset_directory"
>>>
>>> # 1) Read all samples (image files) in rendered_sst2_dataset_dir with 8 threads
>>> dataset = ds.RenderedSST2Dataset(dataset_dir=rendered_sst2_dataset_dir,
usage="all", num_parallel_workers=8)
About RenderedSST2Dataset:
Rendered SST2 is an image classification dataset which was generated by rendering sentences in the Standford
Sentiment Treebank v2 dataset. There are three splits in this dataset and each split contains two classes
(positive and negative): a train split containing 6920 images (3610 positive and 3310 negative), a validation
split containing 872 images (444 positive and 428 negative), and a test split containing 1821 images
(909 positive and 912 negative).
Here is the original RenderedSST2 dataset structure.
You can unzip the dataset files into the following directory structure and read by MindSpore's API.
.. code-block::
.
rendered_sst2_dataset_directory
train
negative
0001.jpg
0002.jpg
...
positive
0001.jpg
0002.jpg
...
test
negative
0001.jpg
0002.jpg
...
positive
0001.jpg
0002.jpg
...
valid
negative
0001.jpg
0002.jpg
...
positive
0001.jpg
0002.jpg
...
Citation:
.. code-block::
@inproceedings{socher-etal-2013-recursive,
title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
Christopher D. and Ng, Andrew and Potts, Christopher},
booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
month = oct,
year = {2013},
address = {Seattle, Washington, USA},
publisher = {Association for Computational Linguistics},
url = {https://www.aclweb.org/anthology/D13-1170},
pages = {1631--1642},
}
"""
@check_rendered_sst2_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
self.decode = replace_none(decode, False)
def parse(self, children=None):
return cde.RenderedSST2Node(self.dataset_dir, self.usage, self.decode, self.sampler)
class _SBDataset:
"""
Dealing with the data file with .mat extension, and return one row in tuple (image, task) each time.
@ -4513,7 +4666,7 @@ class SUN397Dataset(MappableDataset, VisionBaseDataset):
About SUN397Dataset:
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of 397 categories with
108,754 images.The number of images varies across categories, but there are at least 100 images per category.
108,754 images. The number of images varies across categories, but there are at least 100 images per category.
Images are in jpg, png, or gif format.
Here is the original SUN397 dataset structure.

View File

@ -1166,6 +1166,34 @@ def check_random_dataset(method):
return new_method
def check_rendered_sst2_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(RenderedSST2Dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
dataset_dir = param_dict.get('dataset_dir')
usage = param_dict.get('usage')
check_dir(dataset_dir)
if usage is not None:
check_valid_str(usage, ['val', 'all', 'train', 'test'])
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_pad_info(key, val):
"""check the key and value pair of pad_info in batch"""
type_check(key, (str,), "key in pad_info")

View File

@ -0,0 +1,207 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: RenderedSST2Dataset
/// Description: Basic test of RenderedSST2Dataset
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2Dataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2Dataset.";
// Create a RenderedSST2 Dataset.
std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
std::shared_ptr<Dataset> ds = RenderedSST2(folder_path, "all", false, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset in pipeline mode
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetWithPipeline.";
// Create two RenderedSST2 Dataset.
std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
std::shared_ptr<Dataset> ds1 = RenderedSST2(folder_path, "all", false, std::make_shared<RandomSampler>(false, 3));
std::shared_ptr<Dataset> ds2 = RenderedSST2(folder_path, "all", false,std::make_shared<RandomSampler>(false, 3));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds.
int32_t repeat_num = 1;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 1;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds.
std::vector<std::string> column_project = {"image", "label"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 6);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: RenderedSST2Dataset
/// Description: Test getting size of RenderedSST2Dataset
/// Expectation: The size is correct
TEST_F(MindDataTestPipeline, TestRenderedSST2GetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2GetDatasetSize.";
// Create a RenderedSST2 Dataset.
std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
std::shared_ptr<Dataset> ds = RenderedSST2(folder_path);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 12);
}
/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset with mix getter
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2Getters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2MixGetter.";
// Create a RenderedSST2 Dataset.
std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
std::shared_ptr<Dataset> ds = RenderedSST2(folder_path);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 12);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(num_classes, 2);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 12);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 12);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), 2);
EXPECT_EQ(ds->GetDatasetSize(), 12);
}
/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset with the fail of reading dataset
/// Expectation: Throw correct error and message
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetFail.";
// Create a RenderedSST2 Dataset.
std::shared_ptr<Dataset> ds = RenderedSST2("", "train", false, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid RenderedSST2 input.
EXPECT_EQ(iter, nullptr);
}
/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset with the null sampler
/// Expectation: Throw correct error and message
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetWithNullSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetWithNullSamplerFail.";
// Create a RenderedSST2 Dataset.
std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
std::shared_ptr<Dataset> ds = RenderedSST2(folder_path, "train", false, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid RenderedSST2 input, sampler cannot be nullptr.
EXPECT_EQ(iter, nullptr);
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

View File

@ -0,0 +1,222 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test RenderedSST2 dataset operators
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import log as logger
IMAGE_DATA_DIR = "../data/dataset/testRenderedSST2Data"
WRONG_DIR = "../data/dataset/notExist"
def test_rendered_sst2_basic():
"""
Feature: RenderedSST2Dataset
Description: Basic test of RenderedSST2Dataset
Expectation: The data is processed successfully
"""
logger.info("Test RenderedSST2Dataset Op")
# case 1: test read all data
all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False)
all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False)
num_iter = 0
for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["label"], item2["label"])
num_iter += 1
assert num_iter == 12
# case 2: test decode
all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True, shuffle=False)
all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True, shuffle=False)
num_iter = 0
for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["label"], item2["label"])
num_iter += 1
assert num_iter == 12
# case 3: test num_samples
all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=4)
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 4
# case 4: test repeat
all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=4)
all_data = all_data.repeat(2)
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 8
# case 5: test get_dataset_size, resize and batch
all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=8)
all_data = all_data.map(operations=[vision.Decode(), vision.Resize((256, 256))], input_columns=["image"],
num_parallel_workers=1)
assert all_data.get_dataset_size() == 8
assert all_data.get_batch_size() == 1
# drop_remainder is default to be False
all_data = all_data.batch(batch_size=3)
assert all_data.get_batch_size() == 3
assert all_data.get_dataset_size() == 3
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 3
def test_rendered_sst2_decode():
"""
Feature: RenderedSST2Dataset
Description: Validate RenderedSST2Dataset with decode
Expectation: The data is processed successfully
"""
logger.info("Validate RenderedSST2Dataset with decode")
# define parameters
repeat_count = 1
data1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12
def test_rendered_sst2_sequential_sampler():
"""
Feature: RenderedSST2Dataset
Description: Test RenderedSST2Dataset with SequentialSampler
Expectation: The data is processed successfully
"""
logger.info("Test RenderedSST2Dataset Op with SequentialSampler")
num_samples = 4
sampler = ds.SequentialSampler(num_samples=num_samples)
all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=sampler)
all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_samples=num_samples)
label_list_1, label_list_2 = [], []
num_iter = 0
for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
all_data_2.create_dict_iterator(num_epochs=1)):
label_list_1.append(item1["label"].asnumpy())
label_list_2.append(item2["label"].asnumpy())
num_iter += 1
np.testing.assert_array_equal(label_list_1, label_list_2)
assert num_iter == num_samples
def test_rendered_sst2_random_sampler():
"""
Feature: RenderedSST2Dataset
Description: Test RenderedSST2Dataset with RandomSampler
Expectation: The data is processed successfully
"""
logger.info("Test RenderedSST2Dataset Op with RandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.RandomSampler()
data1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12
def test_rendered_sst2_exception():
"""
Feature: RenderedSST2Dataset
Description: Test error cases for RenderedSST2Dataset
Expectation: Throw correct error and message
"""
logger.info("Test error cases for RenderedSST2Dataset")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, sampler=ds.SequentialSampler(1))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=ds.SequentialSampler(1), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=2, shard_id="0")
error_msg_8 = "does not exist or is not a directory or permission denied!"
with pytest.raises(ValueError, match=error_msg_8):
all_data = ds.RenderedSST2Dataset(WRONG_DIR)
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
if __name__ == '__main__':
test_rendered_sst2_basic()
test_rendered_sst2_decode()
test_rendered_sst2_sequential_sampler()
test_rendered_sst2_random_sampler()
test_rendered_sst2_exception()