[assistant][ops]New operator implementation, including RenderedSST2Dataset
|
@ -0,0 +1,118 @@
|
|||
mindspore.dataset.RenderedSST2Dataset
|
||||
=====================================
|
||||
|
||||
.. py:class:: mindspore.dataset.RenderedSST2Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
|
||||
|
||||
Read and parse the source files of the RenderedSST2 dataset and build a dataset.
|
||||
|
||||
The generated dataset has two columns: `[image, label]` . The data type of the `image` column is uint8, and the data type of the `label` column is uint32.
|
||||
|
||||
Parameters:
|
||||
- **dataset_dir** (str) - Path to the root directory that contains the dataset files.
|
||||
- **usage** (str, optional) - Subset of the dataset to be read, can be 'train', 'val', 'test' or 'all'. Default: None, read all sample images.
|
||||
- **num_samples** (int, optional) - Number of samples to read from the dataset, which can be smaller than the total number of samples. Default: None, read all sample images.
|
||||
- **num_parallel_workers** (int, optional) - Number of worker threads used to read the data. Default: None, use the number of threads configured in mindspore.dataset.config.
|
||||
- **shuffle** (bool, optional) - Whether to shuffle the dataset. Default: None, the expected behavior for different configurations is shown in the table below.
|
||||
- **decode** (bool, optional) - Whether to decode the images after reading. Default: False, do not decode.
|
||||
- **sampler** (Sampler, optional) - Sampler used to choose samples from the dataset. Default: None, the expected behavior for different configurations is shown in the table below.
|
||||
- **num_shards** (int, optional) - Number of shards that the dataset will be divided into for distributed training. Default: None. When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
|
||||
- **shard_id** (int, optional) - The shard ID used in distributed training. Default: None. This argument can only be specified when `num_shards` is also specified.
|
||||
- **cache** (DatasetCache, optional) - Single-node data cache service used to speed up dataset processing. For details, please read `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ . Default: None, no cache is used.
|
||||
|
||||
Raises:
|
||||
- **RuntimeError** - `dataset_dir` does not contain any data files.
|
||||
- **RuntimeError** - `sampler` and `shuffle` are specified at the same time.
|
||||
- **RuntimeError** - `sampler` and `num_shards` are specified at the same time, or `sampler` and `shard_id` are specified at the same time.
|
||||
- **RuntimeError** - `num_shards` is specified but `shard_id` is not.
|
||||
- **RuntimeError** - `shard_id` is specified but `num_shards` is not.
|
||||
- **ValueError** - `usage` is not 'train', 'val', 'test' or 'all'.
|
||||
- **ValueError** - `num_parallel_workers` exceeds the maximum number of threads on the system.
|
||||
- **ValueError** - `shard_id` is invalid (less than 0 or greater than or equal to `num_shards` ).
|
||||
|
||||
.. note:: This dataset can take in a `sampler` , but `sampler` and `shuffle` are mutually exclusive. The table below shows the valid input argument combinations and the expected behavior; a short usage sketch follows the table.
|
||||
|
||||
.. list-table:: Expected order behavior of using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - `sampler` instance
|
||||
- None
|
||||
- order defined by `sampler`
|
||||
* - `sampler` instance
|
||||
- True
|
||||
- not allowed
|
||||
* - `sampler` instance
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
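A minimal sketch of the two legal configurations in the table (the path is a placeholder, and `SequentialSampler` is just one example sampler):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
    import mindspore.dataset as ds
|
||||
|
||||
    # Either pass a sampler ...
|
||||
    sampler = ds.SequentialSampler(num_samples=64)
|
||||
    dataset = ds.RenderedSST2Dataset("/path/to/rendered_sst2_dataset_directory", usage="val", sampler=sampler)
|
||||
|
||||
    # ... or request shuffling, but never both at once.
|
||||
    dataset = ds.RenderedSST2Dataset("/path/to/rendered_sst2_dataset_directory", usage="val", shuffle=True)
|
||||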
**About the RenderedSST2 dataset:**
|
||||
|
||||
Rendered SST2 is an image classification dataset generated from the data in the SST2 dataset. The dataset is split into three parts, each of which contains two classes (positive and negative):
|
||||
the train split contains 6920 images in total (3610 positive, 3310 negative), the validation split contains 872 images in total (444 positive, 428 negative),
|
||||
and the test split contains 1821 images in total (909 positive, 912 negative).
|
||||
|
||||
The following is the structure of the original RenderedSST2 dataset. You can unzip the dataset files into the directory structure shown below and read it with MindSpore's API; a short loading sketch follows the tree.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── rendered_sst2_dataset_directory
|
||||
├── train
|
||||
│ ├── negative
|
||||
│ │ ├── 0001.jpg
|
||||
│ │ ├── 0002.jpg
|
||||
│ │ ...
|
||||
│ └── positive
|
||||
│ ├── 0001.jpg
|
||||
│ ├── 0002.jpg
|
||||
│ ...
|
||||
├── test
|
||||
│ ├── negative
|
||||
│ │ ├── 0001.jpg
|
||||
│ │ ├── 0002.jpg
|
||||
│ │ ...
|
||||
│ └── positive
|
||||
│ ├── 0001.jpg
|
||||
│ ├── 0002.jpg
|
||||
│ ...
|
||||
└── valid
|
||||
├── negative
|
||||
│ ├── 0001.jpg
|
||||
│ ├── 0002.jpg
|
||||
│ ...
|
||||
└── positive
|
||||
├── 0001.jpg
|
||||
├── 0002.jpg
|
||||
...
|
||||
|
||||
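For quick reference, a minimal read-and-iterate sketch over the structure above (the path is a placeholder):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
    import mindspore.dataset as ds
|
||||
|
||||
    rendered_sst2_dir = "/path/to/rendered_sst2_dataset_directory"
|
||||
    dataset = ds.RenderedSST2Dataset(dataset_dir=rendered_sst2_dir, usage="train", decode=True)
|
||||
    for item in dataset.create_dict_iterator(output_numpy=True):
|
||||
        image, label = item["image"], item["label"]
|
||||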
**Citation:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
@inproceedings{socher-etal-2013-recursive,
|
||||
title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
|
||||
author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
|
||||
Christopher D. and Ng, Andrew and Potts, Christopher},
|
||||
booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
|
||||
month = oct,
|
||||
year = {2013},
|
||||
address = {Seattle, Washington, USA},
|
||||
publisher = {Association for Computational Linguistics},
|
||||
url = {https://www.aclweb.org/anthology/D13-1170},
|
||||
pages = {1631--1642},
|
||||
}
|
||||
|
||||
|
||||
.. include:: mindspore.dataset.api_list_vision.rst
|
|
@ -120,6 +120,7 @@ mindspore.dataset
|
|||
mindspore.dataset.PhotoTourDataset
|
||||
mindspore.dataset.Places365Dataset
|
||||
mindspore.dataset.QMnistDataset
|
||||
mindspore.dataset.RenderedSST2Dataset
|
||||
mindspore.dataset.SBDataset
|
||||
mindspore.dataset.SBUDataset
|
||||
mindspore.dataset.SemeionDataset
|
||||
|
|
|
@ -32,6 +32,7 @@ Vision
|
|||
mindspore.dataset.PhotoTourDataset
|
||||
mindspore.dataset.Places365Dataset
|
||||
mindspore.dataset.QMnistDataset
|
||||
mindspore.dataset.RenderedSST2Dataset
|
||||
mindspore.dataset.SBDataset
|
||||
mindspore.dataset.SBUDataset
|
||||
mindspore.dataset.SemeionDataset
|
||||
|
|
|
@ -116,6 +116,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h"
|
||||
|
@ -2005,6 +2006,33 @@ RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vecto
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds =
|
||||
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds =
|
||||
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds =
|
||||
std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
|
|
|
@ -56,6 +56,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
|
||||
|
@ -642,6 +643,18 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RenderedSST2Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RenderedSST2Node, DatasetNode, std::shared_ptr<RenderedSST2Node>>(
|
||||
*m, "RenderedSST2Node", "to create a RenderedSST2Node")
|
||||
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const py::handle &sampler) {
|
||||
auto rendered_sst2 =
|
||||
std::make_shared<RenderedSST2Node>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(rendered_sst2->ValidateParams());
|
||||
return rendered_sst2;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
|
||||
"to create an SBUNode")
|
||||
|
|
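Aside: the RenderedSST2Node binding registered above is what the Python-layer dataset class is expected to invoke when parsing itself into the C++ IR. A hedged sketch of that call shape (the module alias follows MindSpore convention; treat the class body below as an assumption rather than a quote from this patch):

    # Python layer, roughly:
    import mindspore._c_dataengine as cde  # C++ binding module, per MindSpore convention

    class RenderedSST2Dataset:  # simplified; the real class derives from MappableDataset
        def parse(self, children=None):
            # Arguments mirror the py::init above: dataset_dir, usage, decode, sampler.
            return cde.RenderedSST2Node(self.dataset_dir, self.usage, self.decode, self.sampler)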
|
@ -44,6 +44,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
places365_op.cc
|
||||
qmnist_op.cc
|
||||
random_data_op.cc
|
||||
rendered_sst2_op.cc
|
||||
sbu_op.cc
|
||||
semeion_op.cc
|
||||
sogou_news_op.cc
|
||||
|
|
|
@ -0,0 +1,373 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/rendered_sst2_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
RenderedSST2Op::RenderedSST2Op(int32_t num_wkrs, const std::string &file_dir, const std::string &usage,
|
||||
int32_t queue_size, bool do_decode, const std::set<std::string> &exts,
|
||||
const std::map<std::string, uint32_t> &map, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_wkrs, queue_size, std::move(sampler)),
|
||||
folder_path_(file_dir),
|
||||
usage_(usage),
|
||||
decode_(do_decode),
|
||||
extensions_(exts),
|
||||
class_index_(map),
|
||||
data_schema_(std::move(data_schema)),
|
||||
sampler_ind_(0) {
|
||||
folder_path_queue_ = std::make_unique<Queue<std::string>>(num_wkrs * queue_size);
|
||||
folder_classId_queue_ = std::make_unique<Queue<uint32_t>>(num_wkrs * queue_size);
|
||||
image_name_queue_ = std::make_unique<Queue<FolderImagesPair>>(num_wkrs * queue_size);
|
||||
}
|
||||
|
||||
// Master thread that pulls the prescan workers' results.
|
||||
// Keep collecting results until all prescan workers quit,
|
||||
// then consolidate the two levels of sorted results into one vector,
|
||||
// calculate num_rows and return.
|
||||
Status RenderedSST2Op::PrepareData() {
|
||||
std::vector<FolderImagesPair> v;
|
||||
int64_t cnt = 0;
|
||||
while (cnt != num_workers_) { // count number of end signals
|
||||
FolderImagesPair p;
|
||||
RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p));
|
||||
if (p == nullptr) {
|
||||
cnt++;
|
||||
} else {
|
||||
v.push_back(p);
|
||||
}
|
||||
}
|
||||
std::sort(v.begin(), v.end(),
|
||||
[](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; });
|
||||
// The following loop merges the two levels of sorted results into one vector.
|
||||
for (size_t ind = 0; ind < v.size(); ++ind) {
|
||||
while (!v[ind]->second.empty()) {
|
||||
image_label_pairs_.push_back(v[ind]->second.front());
|
||||
image_prefix_.push_back(v[ind]->first);
|
||||
v[ind]->second.pop();
|
||||
}
|
||||
}
|
||||
image_label_pairs_.shrink_to_fit();
|
||||
num_rows_ = image_label_pairs_.size();
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||
DatasetName() + " file path: " + folder_path_);
|
||||
}
|
||||
// free memory of two queues used for pre-scan
|
||||
folder_path_queue_->Reset();
|
||||
folder_classId_queue_->Reset();
|
||||
image_name_queue_->Reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Load 1 TensorRow (image, label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow.
|
||||
Status RenderedSST2Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
ImageLabelPair pair_ptr = image_label_pairs_[row_id];
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(image_prefix_[row_id] + pair_ptr->first, &image));
|
||||
|
||||
if (decode_) {
|
||||
Status rc = Decode(image, &image);
|
||||
if (rc.IsError()) {
|
||||
std::string err = "Invalid image, " + folder_path_ + (pair_ptr->first) +
|
||||
" decode failed, the image is broken or permission denied.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RenderedSST2Op::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows: " << num_rows_ << "\n"
|
||||
<< DatasetName(true) << " directory: " << folder_path_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status RenderedSST2Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
||||
if (image_label_pairs_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid dataset_dir, " + DatasetName(true) +
|
||||
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||
DatasetName() + " file path: " + folder_path_);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR] Map containing image-index pairs is nullptr or has been set elsewhere, "
|
||||
"it must be empty before using GetClassIds.");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < image_label_pairs_.size(); ++i) {
|
||||
(*cls_ids)[image_label_pairs_[i]->second].push_back(i);
|
||||
}
|
||||
for (auto &pair : (*cls_ids)) {
|
||||
pair.second.shrink_to_fit();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Worker entry for pre-scanning all the folders and doing the 1st-level sort.
|
||||
// A worker pulls a folder path from folder_path_queue_ (which is a Queue) and walks all the images under that folder path.
|
||||
// After walking is complete, sort all the file names (relative paths of all png files under the same directory).
|
||||
// (Sorting happens automatically because a std::set, implemented as a red-black tree, is used.)
|
||||
// Add the sorted filenames into a queue, then make a pair (folderpath, queue<filenames>*);
|
||||
// the folderpath is used for the 2nd-level sorting.
|
||||
// FYI: 1st-level sorting: sort all images under the same directory.
|
||||
// FYI: 2nd-level sorting: sort all folder names.
|
||||
// Push this pair to image_name_queue_ (which is again a Queue).
|
||||
Status RenderedSST2Op::PrescanWorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::string folder_;
|
||||
uint32_t current_class_id;
|
||||
RETURN_IF_NOT_OK(folder_path_queue_->PopFront(&folder_));
|
||||
RETURN_IF_NOT_OK(folder_classId_queue_->PopFront(¤t_class_id));
|
||||
while (!folder_.empty()) {
|
||||
Path folder(folder_);
|
||||
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
|
||||
if (!folder.Exists() || dirItr == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid dataset_dir, " + folder_ + " does not exist or permission denied.");
|
||||
}
|
||||
auto offset_ = folder_.size();
|
||||
std::set<std::string> imgs; // use this for ordering
|
||||
while (dirItr->HasNext()) {
|
||||
Path file = dirItr->Next();
|
||||
if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) {
|
||||
(void)imgs.insert(file.ToString().substr(offset_));
|
||||
} else {
|
||||
MS_LOG(WARNING) << DatasetName(true) << " operator unsupported file found: " << file.ToString()
|
||||
<< ", extension: " << file.Extension() << ".";
|
||||
}
|
||||
}
|
||||
FolderImagesPair p = std::make_shared<std::pair<std::string, std::queue<ImageLabelPair>>>();
|
||||
p->first = folder_;
|
||||
for (const std::string &img : imgs) {
|
||||
p->second.push(std::make_shared<std::pair<std::string, uint32_t>>(img, current_class_id));
|
||||
}
|
||||
RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p));
|
||||
RETURN_IF_NOT_OK(folder_path_queue_->PopFront(&folder_));
|
||||
RETURN_IF_NOT_OK(folder_classId_queue_->PopFront(¤t_class_id));
|
||||
}
|
||||
RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr)); // end signal
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// This helper function walks all the subfolders under `dir` and sends each folder path to folder_path_queue_.
|
||||
Status RenderedSST2Op::WalkFolder(Path *dir) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dir);
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_itr);
|
||||
auto offset_ = dir->ToString().size();
|
||||
std::string current_class;
|
||||
while (dir_itr->HasNext()) {
|
||||
Path subdir = dir_itr->Next();
|
||||
if (subdir.IsDirectory()) {
|
||||
RETURN_IF_NOT_OK(folder_path_queue_->EmplaceBack(subdir.ToString()));
|
||||
current_class = subdir.ToString().substr(offset_ + 1);
|
||||
if (class_index_.find(current_class) == class_index_.end()) {
|
||||
class_index_[current_class] = class_index_.size();
|
||||
}
|
||||
RETURN_IF_NOT_OK(folder_classId_queue_->EmplaceBack(class_index_[current_class]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A thread that calls WalkFolder
|
||||
Status RenderedSST2Op::StartAsyncWalk() {
|
||||
TaskManager::FindMe()->Post();
|
||||
Path dir(folder_path_);
|
||||
|
||||
if (!dir.Exists() || !dir.IsDirectory()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, " + folder_path_ + " may not exist or is not a directory.");
|
||||
}
|
||||
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_itr);
|
||||
auto offset_ = folder_path_.length();
|
||||
while (dir_itr->HasNext()) {
|
||||
Path subdir = dir_itr->Next();
|
||||
if (subdir.IsDirectory()) {
|
||||
std::string name = subdir.ToString().substr(offset_ + 1);
|
||||
if (usage_ == name) {
|
||||
RETURN_IF_NOT_OK(WalkFolder(&subdir));
|
||||
} else if (usage_ == "val" && name == "valid") {
|
||||
RETURN_IF_NOT_OK(WalkFolder(&subdir));
|
||||
} else if (usage_ == "all" && (name == "train" || name == "test" || name == "valid")) {
|
||||
RETURN_IF_NOT_OK(WalkFolder(&subdir));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Send out num_workers_ end signals to folder_path_queue_, 1 for each worker.
|
||||
// Upon receiving an end signal, a worker quits and sends another end signal to image_name_queue_.
|
||||
for (int32_t ind = 0; ind < num_workers_; ++ind) {
|
||||
RETURN_IF_NOT_OK(folder_path_queue_->EmplaceBack("")); // end signal
|
||||
RETURN_IF_NOT_OK(folder_classId_queue_->EmplaceBack(0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Op::RegisterAndLaunchThreads() {
|
||||
RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
|
||||
RETURN_IF_NOT_OK(folder_path_queue_->Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(folder_classId_queue_->Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks()));
|
||||
|
||||
// The following code launches 3 thread groups:
|
||||
// 1) A thread that walks all folders and pushes the folder names to a util::Queue, folder_path_queue_.
|
||||
// 2) Workers that pull a folder name from folder_path_queue_, walk it and return the sorted images to image_name_queue_.
|
||||
// 3) Main workers that load TensorRows by reading all images.
|
||||
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(Name() + "::WalkDir",
|
||||
std::bind(&RenderedSST2Op::StartAsyncWalk, this), nullptr, id()));
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
|
||||
std::bind(&RenderedSST2Op::PrescanWorkerEntry, this, std::placeholders::_1),
|
||||
Name() + "::PrescanWorkerEntry", id()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Op::WalkFolderForCountRows(Path *dir, std::queue<std::string> *folder_paths,
|
||||
std::map<std::string, uint32_t> *class_index) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dir);
|
||||
RETURN_UNEXPECTED_IF_NULL(folder_paths);
|
||||
RETURN_UNEXPECTED_IF_NULL(class_index);
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_itr);
|
||||
std::string current_class;
|
||||
auto offset_ = dir->ToString().size();
|
||||
while (dir_itr->HasNext()) {
|
||||
Path subdir = dir_itr->Next();
|
||||
if (subdir.IsDirectory()) {
|
||||
folder_paths->push(subdir.ToString());
|
||||
current_class = subdir.ToString().substr(offset_ + 1);
|
||||
if (class_index->find(current_class) == class_index->end()) {
|
||||
(*class_index)[current_class] = class_index->size();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Op::CountRows(std::queue<std::string> *folder_paths, int64_t *num_rows,
|
||||
const std::set<std::string> &exts) {
|
||||
RETURN_UNEXPECTED_IF_NULL(folder_paths);
|
||||
RETURN_UNEXPECTED_IF_NULL(num_rows);
|
||||
int64_t row_cnt = 0;
|
||||
while (!folder_paths->empty()) {
|
||||
Path subdir(folder_paths->front());
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&subdir);
|
||||
if (!subdir.Exists() || dir_itr == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid subdirectory, RenderedSST2 Dataset subdirectory: " + subdir.ToString() +
|
||||
" does not exist or permission denied");
|
||||
}
|
||||
while (dir_itr->HasNext()) {
|
||||
if (exts.empty() || exts.find(dir_itr->Next().Extension()) != exts.end()) {
|
||||
++row_cnt;
|
||||
}
|
||||
}
|
||||
folder_paths->pop();
|
||||
}
|
||||
(*num_rows) = row_cnt;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Op::CountRowsAndClasses(const std::string &path, const std::string &usage,
|
||||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes) {
|
||||
Path dir(path);
|
||||
std::string err_msg = "";
|
||||
err_msg += (!dir.Exists() || !dir.IsDirectory())
|
||||
? "Invalid dataset_dir, " + path + " does not exist or the path is not a directory. "
|
||||
: "";
|
||||
err_msg += (num_classes == nullptr && num_rows == nullptr) ? "[Internal ERROR] num_class and num_rows are null." : "";
|
||||
if (!err_msg.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
std::queue<std::string> folder_paths;
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
|
||||
std::map<std::string, uint32_t> class_index;
|
||||
auto offset_ = path.size();
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_itr);
|
||||
|
||||
while (dir_itr->HasNext()) {
|
||||
Path subdir = dir_itr->Next();
|
||||
if (subdir.IsDirectory()) {
|
||||
std::string name = subdir.ToString().substr(offset_ + 1);
|
||||
name = name == "valid" ? "val" : name;
|
||||
if (usage == name) {
|
||||
RETURN_IF_NOT_OK(WalkFolderForCountRows(&subdir, &folder_paths, &class_index));
|
||||
} else if (usage == "all" && (name == "train" || name == "test" || name == "val")) {
|
||||
RETURN_IF_NOT_OK(WalkFolderForCountRows(&subdir, &folder_paths, &class_index));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (num_classes != nullptr) {
|
||||
*num_classes = class_index.size();
|
||||
}
|
||||
// Return here if only num_classes is needed.
|
||||
RETURN_OK_IF_TRUE(num_rows == nullptr);
|
||||
|
||||
RETURN_IF_NOT_OK(CountRows(&folder_paths, num_rows, exts));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Op::ComputeColMap() {
|
||||
// Set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get number of classes
|
||||
Status RenderedSST2Op::GetNumClasses(int64_t *num_classes) {
|
||||
RETURN_UNEXPECTED_IF_NULL(num_classes);
|
||||
if (num_classes_ > 0) {
|
||||
*num_classes = num_classes_;
|
||||
return Status::OK();
|
||||
}
|
||||
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, usage_, extensions_, nullptr, num_classes));
|
||||
num_classes_ = *num_classes;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
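A note on the class-id scheme in WalkFolder above: ids are handed out to class subfolders in first-seen iteration order (class_index_[current_class] = class_index_.size()), not by sorting the names. A small Python sketch of the equivalent logic (the function name and paths are illustrative):

    import os

    def build_class_index(split_dir):
        # Mirrors WalkFolder: each newly seen class subfolder gets the next id,
        # in whatever order the directory iterator yields entries.
        class_index = {}
        for name in os.listdir(split_dir):
            if os.path.isdir(os.path.join(split_dir, name)) and name not in class_index:
                class_index[name] = len(class_index)
        return class_index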
@ -0,0 +1,172 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Forward declares
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
using ImageLabelPair = std::shared_ptr<std::pair<std::string, uint32_t>>;
|
||||
using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>;
|
||||
|
||||
class RenderedSST2Op : public MappableLeafOp {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param int32_t num_wkrs - Num of workers reading images in parallel.
|
||||
/// @param const std::string &file_dir - Directory of RenderedSST2Dataset.
|
||||
/// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'val' or 'all'.
|
||||
/// @param int32_t queue_size - Connector queue size.
|
||||
/// @param bool do_decode - Decode the images after reading.
|
||||
/// @param std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the dir.
|
||||
/// @param const std::map<std::string, uint32_t> &map - Map of class name and class id.
|
||||
/// @param std::unique_ptr<DataSchema> data_schema - Schema of data.
|
||||
/// @param std::shared_ptr<SamplerRT> sampler - Sampler tells RenderedSST2Op what to read.
|
||||
RenderedSST2Op(int32_t num_wkrs, const std::string &file_dir, const std::string &usage, int32_t queue_size,
|
||||
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, uint32_t> &map,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// Destructor.
|
||||
~RenderedSST2Op() override = default;
|
||||
|
||||
/// Initialize RenderedSST2Op related var, calls the function to walk all files.
|
||||
/// @return Status The status code returned.
|
||||
Status PrepareData() override;
|
||||
|
||||
/// Prescan worker: pulls folder paths from folder_path_queue_, walks them and pushes the sorted image lists to image_name_queue_.
|
||||
/// @param int32_t worker_id - Id of each worker.
|
||||
/// @return Status The status code returned.
|
||||
Status PrescanWorkerEntry(int32_t worker_id);
|
||||
|
||||
/// Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
|
||||
/// @param (std::map<int32_t, std::vector<int64_t >> * cls_ids - Key label, val all ids for this class.
|
||||
/// @return Status The status code returned.
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
/// A print method typically used for debugging.
|
||||
/// @param std::ostream &out - Out stream.
|
||||
/// @param bool show_all - Whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// This function is a hack! It returns the num_classes and num_rows. The result
|
||||
/// returned by this function may not be consistent with what RenderedSST2Op is going to return,
|
||||
/// so use this at your own risk!
|
||||
/// @param const std::string &path - Directory of RenderedSST2Dataset.
|
||||
/// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'val' or 'all'.
|
||||
/// @param const std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the
|
||||
/// dir.
|
||||
/// @param int64_t *num_rows - The number of rows.
|
||||
/// @param int64_t *num_classes - The number of classes.
|
||||
/// @return Status of the function.
|
||||
static Status CountRowsAndClasses(const std::string &path, const std::string &usage,
|
||||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes);
|
||||
|
||||
/// This helper function is used to count the num_rows.
|
||||
/// @param std::queue<std::string> *folder_paths - A queue contains all the image folder paths.
|
||||
/// @param int64_t *num_rows - The number of rows.
|
||||
/// @param const std::set<std::string> &exts - Set of file extensions to read, if empty, read everything under the
|
||||
/// dir.
|
||||
/// @return Status of the function.
|
||||
static Status CountRows(std::queue<std::string> *folder_paths, int64_t *num_rows, const std::set<std::string> &exts);
|
||||
|
||||
/// Op name getter.
|
||||
/// @return Name of the current Op.
|
||||
std::string Name() const override { return "RenderedSST2Op"; }
|
||||
|
||||
/// DatasetName name getter.
|
||||
/// @return DatasetName of the current Op.
|
||||
virtual std::string DatasetName(bool upper = false) const { return upper ? "RenderedSST2" : "rendered sst2"; }
|
||||
|
||||
/// Base-class override for GetNumClasses.
|
||||
/// @param num_classes - The number of classes.
|
||||
/// @return Status of the function.
|
||||
Status GetNumClasses(int64_t *num_classes) override;
|
||||
|
||||
protected:
|
||||
/// Load a tensor row according to a pair.
|
||||
/// @param row_id_type row_id - Id for this tensor row.
|
||||
/// @param TensorRow *row - Image & label read into this tensor row.
|
||||
/// @return Status The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
/// @param Path * dir - Dir to walk all folders.
|
||||
/// @return Status The status code returned.
|
||||
virtual Status WalkFolder(Path *dir);
|
||||
|
||||
/// @param Path * dir - Dir to walk all images.
|
||||
/// @param std::queue<std::string> *folder_paths - A queue contains all the image folder paths.
|
||||
/// @param std::map<std::string, uint32_t> *class_index - A map that records each class and its id.
|
||||
/// @return Status The status code returned.
|
||||
static Status WalkFolderForCountRows(Path *dir, std::queue<std::string> *folder_paths,
|
||||
std::map<std::string, uint32_t> *class_index);
|
||||
|
||||
/// Start walking all the dirs.
|
||||
/// @return Status The status code returned.
|
||||
Status StartAsyncWalk();
|
||||
|
||||
/// Register the internal queues and launch all worker threads; called when the op starts to run.
|
||||
/// @return Status The status code returned.
|
||||
Status RegisterAndLaunchThreads() override;
|
||||
|
||||
/// Private function for computing the assignment of the column name map.
|
||||
/// @return Status The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
std::string folder_path_; // directory of image folder
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::set<std::string> extensions_; // extensions allowed
|
||||
std::map<std::string, uint32_t> class_index_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
int64_t sampler_ind_;
|
||||
std::vector<ImageLabelPair> image_label_pairs_;
|
||||
std::vector<std::string> image_prefix_;
|
||||
std::unique_ptr<Queue<std::string>> folder_path_queue_;
|
||||
std::unique_ptr<Queue<uint32_t>> folder_classId_queue_; // the class Id of the images under the folder
|
||||
std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RENDERED_SST2_OP_H_
|
|
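For orientation, CountRowsAndClasses declared above only walks the requested split(s) and counts the files whose extensions are allowed; the on-disk folder "valid" is matched against the usage value "val". An equivalent Python sketch under those assumptions (names are illustrative):

    import os

    def count_rows_and_classes(path, usage, exts=(".png",)):
        wanted = {"train", "test", "val"} if usage == "all" else {usage}
        num_rows, classes = 0, set()
        for split in os.listdir(path):
            split_dir = os.path.join(path, split)
            name = "val" if split == "valid" else split
            if name not in wanted or not os.path.isdir(split_dir):
                continue
            for cls in os.listdir(split_dir):
                cls_dir = os.path.join(split_dir, cls)
                if not os.path.isdir(cls_dir):
                    continue
                classes.add(cls)
                num_rows += sum(1 for f in os.listdir(cls_dir)
                                if not exts or os.path.splitext(f)[1] in exts)
        return num_rows, len(classes)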
@ -118,6 +118,7 @@ constexpr char kPhotoTourNode[] = "PhotoTourDataset";
|
|||
constexpr char kPlaces365Node[] = "Places365Dataset";
|
||||
constexpr char kQMnistNode[] = "QMnistDataset";
|
||||
constexpr char kRandomNode[] = "RandomDataset";
|
||||
constexpr char kRenderedSST2Node[] = "RenderedSST2Dataset";
|
||||
constexpr char kSBUNode[] = "SBUDataset";
|
||||
constexpr char kSemeionNode[] = "SemeionDataset";
|
||||
constexpr char kSogouNewsNode[] = "SogouNewsDataset";
|
||||
|
|
|
@ -45,6 +45,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
places365_node.cc
|
||||
qmnist_node.cc
|
||||
random_node.cc
|
||||
rendered_sst2_node.cc
|
||||
sbu_node.cc
|
||||
semeion_node.cc
|
||||
sogou_news_node.cc
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/rendered_sst2_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const std::set<std::string> kExts = {".png"};
|
||||
|
||||
RenderedSST2Node::RenderedSST2Node(const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> RenderedSST2Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<RenderedSST2Node>(dataset_dir_, usage_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RenderedSST2Node::Print(std::ostream &out) const {
|
||||
out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
|
||||
}
|
||||
|
||||
Status RenderedSST2Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("RenderedSST2Node", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("RenderedSST2Node", usage_, {"val", "train", "all", "test"}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("RenderedSST2Node", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
// This arg exists in RenderedSST2Op, but is not externalized (in Python API).
|
||||
RETURN_UNEXPECTED_IF_NULL(node_ops);
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
const std::map<std::string, uint32_t> kClassIndex = {};
|
||||
auto op = std::make_shared<RenderedSST2Op>(num_workers_, dataset_dir_, usage_, connector_que_size_, decode_, kExts,
|
||||
kClassIndex, std::move(schema), std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node.
|
||||
Status RenderedSST2Node::GetShardId(int32_t *shard_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(shard_id);
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size.
|
||||
Status RenderedSST2Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dataset_size);
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t sample_size, num_rows;
|
||||
RETURN_IF_NOT_OK(RenderedSST2Op::CountRowsAndClasses(dataset_dir_, usage_, kExts, &num_rows, nullptr));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
RETURN_UNEXPECTED_IF_NULL(size_getter);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RenderedSST2Node::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
args["decode"] = decode_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status RenderedSST2Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kRenderedSST2Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kRenderedSST2Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kRenderedSST2Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kRenderedSST2Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kRenderedSST2Node));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
bool decode = json_obj["decode"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
|
||||
std::shared_ptr<DatasetCache> cache = nullptr;
|
||||
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
|
||||
*ds = std::make_shared<RenderedSST2Node>(dataset_dir, usage, decode, sampler, cache);
|
||||
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
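Since to_json/from_json are implemented above, the generic pipeline serde path should cover this node as well. A hedged usage sketch (the file name and path are placeholders):

    import mindspore.dataset as ds

    dataset = ds.RenderedSST2Dataset("/path/to/rendered_sst2_dataset_directory", usage="test")
    ds.serialize(dataset, json_filepath="rendered_sst2_pipeline.json")  # dumps the pipeline, including this node
    restored = ds.deserialize(json_filepath="rendered_sst2_pipeline.json")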
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \class RenderedSST2Node.
|
||||
/// \brief A Dataset derived class to represent RenderedSST2 dataset.
|
||||
class RenderedSST2Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
RenderedSST2Node(const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~RenderedSST2Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kRenderedSST2Node; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief A base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[out] shard_id The shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size The size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
/// \return The dataset_dir
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
/// \return The usage
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
/// \return The Decode.
|
||||
bool Decode() const { return decode_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to read dataset in json.
|
||||
/// \param[in] json_obj The JSON object to be deserialized.
|
||||
/// \param[out] ds Deserialized dataset.
|
||||
/// \return Status The status code returned.
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENDERED_SST2_NODE_H_
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
|
|
@ -4882,6 +4882,99 @@ std::shared_ptr<RandomDataDataset> DATASET_API RandomData(const int32_t &total_r
|
|||
return ds;
|
||||
}
|
||||
|
||||
/// \class RenderedSST2Dataset
|
||||
/// \brief A source dataset for reading and parsing RenderedSST2 dataset.
|
||||
class DATASET_API RenderedSST2Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of RenderedSST2Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of RenderedSST2Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of RenderedSST2Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of RenderedSST2Dataset.
|
||||
~RenderedSST2Dataset() override = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a RenderedSST2Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all". Default: "all".
|
||||
/// \param[in] decode Decode the images after reading. Default: false.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset. Default: RandomSampler().
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the RenderedSST2Dataset.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define dataset path and MindData object */
|
||||
/// std::string dataset_path = "/path/to/RenderedSST2_dataset_directory";
|
||||
/// std::shared_ptr<Dataset> ds = RenderedSST2(dataset_path);
|
||||
///
|
||||
/// /* Create iterator to read dataset */
|
||||
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In RenderedSST2 dataset, each data dictionary has keys "image" and "label" */
|
||||
/// auto image = row["image"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
|
||||
RenderedSST2(const std::string &dataset_dir, const std::string &usage = "all", bool decode = false,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a RenderedSST2Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the RenderedSST2Dataset.
|
||||
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
|
||||
RenderedSST2(const std::string &dataset_dir, const std::string &usage, bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a RenderedSST2Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test", "val" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the RenderedSST2Dataset.
|
||||
inline std::shared_ptr<RenderedSST2Dataset> DATASET_API
|
||||
RenderedSST2(const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<RenderedSST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class SBUDataset
|
||||
/// \brief A source dataset that reads and parses SBU dataset.
|
||||
class DATASET_API SBUDataset : public Dataset {
|
||||
|
@ -5268,7 +5361,6 @@ class DATASET_API SST2Dataset : public Dataset {
|
|||
/// \param[in] shuffle The mode for shuffling data every epoch. Default: ShuffleMode::kGlobal.
|
||||
/// Can be any of:
|
||||
/// ShuffleMode::kFalse - No shuffling is performed.
|
||||
/// ShuffleMode::kFiles - Shuffle files only.
|
||||
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into. Default: 1.
|
||||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
|
|
|
@ -65,6 +65,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class Places365Dataset;
|
||||
friend class QMnistDataset;
|
||||
friend class RandomDataDataset;
|
||||
friend class RenderedSST2Dataset;
|
||||
friend class SBUDataset;
|
||||
friend class SemeionDataset;
|
||||
friend class SpeechCommandsDataset;
|
||||
|
|
|
@ -63,6 +63,7 @@ __all__ = ["Caltech101Dataset", # Vision
|
|||
"Places365Dataset", # Vision
|
||||
"QMnistDataset", # Vision
|
||||
"RandomDataset", # Vision
|
||||
"RenderedSST2Dataset", # Vision
|
||||
"SBDataset", # Vision
|
||||
"SBUDataset", # Vision
|
||||
"SemeionDataset", # Vision
|
||||
|
|
|
@ -37,8 +37,9 @@ from .validators import check_caltech101_dataset, check_caltech256_dataset, chec
    check_flickr_dataset, check_flowers102dataset, check_food101_dataset, check_imagefolderdataset, \
    check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
    check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
-   check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
-   check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset, check_sun397_dataset
+   check_random_dataset, check_rendered_sst2_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, \
+   check_stl10_dataset, check_sun397_dataset, check_svhn_dataset, check_usps_dataset, check_vocdataset, \
+   check_wider_face_dataset

from ..core.validator_helpers import replace_none
@ -3870,6 +3871,158 @@ class RandomDataset(SourceDataset, VisionBaseDataset):
        return cde.RandomNode(self.total_rows, schema, self.columns_list)


class RenderedSST2Dataset(MappableDataset, VisionBaseDataset):
    """
    A source dataset that reads and parses RenderedSST2 dataset.

    The generated dataset has two columns: :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is of the uint32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be 'train', 'val', 'test' or 'all'.
            Default: None, will read all samples.
        num_samples (int, optional): The number of images to be included in the dataset.
            Default: None, will include all images.
        num_parallel_workers (int, optional): Number of workers to read the data.
            Default: None, set in the config.
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
            Default: None, expected order behavior shown in the table below.
        decode (bool, optional): Whether or not to decode the images after reading. Default: False.
        sampler (Sampler, optional): Object used to choose samples from the
            dataset. Default: None, expected order behavior shown in the table below.
        num_shards (int, optional): Number of shards that the dataset will be divided
            into. When this argument is specified, `num_samples` reflects
            the maximum number of samples per shard. Default: None.
        shard_id (int, optional): The shard ID within `num_shards` . This
            argument can only be specified when `num_shards` is also specified. Default: None.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: None, which means no cache is used.

    Raises:
        RuntimeError: If `dataset_dir` does not contain data files.
        ValueError: If `usage` is not 'train', 'test', 'val' or 'all'.
        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
        RuntimeError: If `sampler` and `shuffle` are specified at the same time.
        RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).

    Note:
        - This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> rendered_sst2_dataset_dir = "/path/to/rendered_sst2_dataset_directory"
        >>>
        >>> # 1) Read all samples (image files) in rendered_sst2_dataset_dir with 8 threads
        >>> dataset = ds.RenderedSST2Dataset(dataset_dir=rendered_sst2_dataset_dir,
        ...                                  usage="all", num_parallel_workers=8)

    About RenderedSST2Dataset:

    Rendered SST2 is an image classification dataset which was generated by rendering sentences in the Stanford
    Sentiment Treebank v2 dataset. There are three splits in this dataset and each split contains two classes
    (positive and negative): a train split containing 6920 images (3610 positive and 3310 negative), a validation
    split containing 872 images (444 positive and 428 negative), and a test split containing 1821 images
    (909 positive and 912 negative).

    Here is the original RenderedSST2 dataset structure.
    You can unzip the dataset files into the following directory structure and read them with MindSpore's API.

    .. code-block::

        .
        └── rendered_sst2_dataset_directory
            ├── train
            │    ├── negative
            │    │    ├── 0001.jpg
            │    │    ├── 0002.jpg
            │    │    ...
            │    └── positive
            │         ├── 0001.jpg
            │         ├── 0002.jpg
            │         ...
            ├── test
            │    ├── negative
            │    │    ├── 0001.jpg
            │    │    ├── 0002.jpg
            │    │    ...
            │    └── positive
            │         ├── 0001.jpg
            │         ├── 0002.jpg
            │         ...
            └── valid
                 ├── negative
                 │    ├── 0001.jpg
                 │    ├── 0002.jpg
                 │    ...
                 └── positive
                      ├── 0001.jpg
                      ├── 0002.jpg
                      ...

    Citation:

    .. code-block::

        @inproceedings{socher-etal-2013-recursive,
            title     = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
            author    = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
                         Christopher D. and Ng, Andrew and Potts, Christopher},
            booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
            month     = oct,
            year      = {2013},
            address   = {Seattle, Washington, USA},
            publisher = {Association for Computational Linguistics},
            url       = {https://www.aclweb.org/anthology/D13-1170},
            pages     = {1631--1642},
        }
    """

    @check_rendered_sst2_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
                 decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir
        self.usage = replace_none(usage, "all")
        self.decode = replace_none(decode, False)

    def parse(self, children=None):
        return cde.RenderedSST2Node(self.dataset_dir, self.usage, self.decode, self.sampler)
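As a quick end-to-end illustration of the class defined above, here is a minimal pipeline sketch (placeholder path; an illustration, not part of the patch): read, decode, resize, then batch the rendered images.

.. code-block:: python

    import mindspore.dataset as ds
    import mindspore.dataset.vision as vision

    # Build the dataset, decode the raw JPEG bytes inside a map op, then batch.
    data = ds.RenderedSST2Dataset("/path/to/rendered_sst2_dataset_directory",
                                  usage="train", shuffle=True)
    data = data.map(operations=[vision.Decode(), vision.Resize((224, 224))],
                    input_columns=["image"])
    data = data.batch(32, drop_remainder=True)

    for batch in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(batch["image"].shape, batch["label"].shape)  # e.g. (32, 224, 224, 3) (32,)
        break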


class _SBDataset:
    """
    Dealing with the data file with .mat extension, and return one row in tuple (image, task) each time.
@ -4513,7 +4666,7 @@ class SUN397Dataset(MappableDataset, VisionBaseDataset):
    About SUN397Dataset:

    The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of 397 categories with
-   108,754 images.The number of images varies across categories, but there are at least 100 images per category.
+   108,754 images. The number of images varies across categories, but there are at least 100 images per category.
    Images are in jpg, png, or gif format.

    Here is the original SUN397 dataset structure.
@ -1166,6 +1166,34 @@ def check_random_dataset(method):
    return new_method


def check_rendered_sst2_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RenderedSST2Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        usage = param_dict.get('usage')
        check_dir(dataset_dir)
        if usage is not None:
            check_valid_str(usage, ['val', 'all', 'train', 'test'])

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_pad_info(key, val):
    """check the key and value pair of pad_info in batch"""
    type_check(key, (str,), "key in pad_info")
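The wrapper above validates arguments at construction time, before any file I/O. A minimal sketch of how an invalid `usage` value is expected to surface, assuming `dataset_dir` points at a real copy of the dataset:

.. code-block:: python

    import mindspore.dataset as ds

    try:
        # 'validation' is not among the accepted values 'train', 'val', 'test', 'all'.
        ds.RenderedSST2Dataset("/path/to/rendered_sst2_dataset_directory", usage="validation")
    except ValueError as err:
        print(err)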
@ -0,0 +1,207 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"

using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;

class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};

/// Feature: RenderedSST2Dataset
/// Description: Basic test of RenderedSST2Dataset
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2Dataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2Dataset.";

  // Create a RenderedSST2 Dataset.
  std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
  std::shared_ptr<Dataset> ds = RenderedSST2(folder_path, "all", false, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline.
  iter->Stop();
}

/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset in pipeline mode
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetWithPipeline.";

  // Create two RenderedSST2 Datasets.
  std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
  std::shared_ptr<Dataset> ds1 = RenderedSST2(folder_path, "all", false, std::make_shared<RandomSampler>(false, 3));
  std::shared_ptr<Dataset> ds2 = RenderedSST2(folder_path, "all", false, std::make_shared<RandomSampler>(false, 3));
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create two Repeat operations on the datasets.
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operations on the datasets.
  std::vector<std::string> column_project = {"image", "label"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Create a Concat operation on the datasets.
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline.
  iter->Stop();
}

/// Feature: RenderedSST2Dataset
/// Description: Test getting size of RenderedSST2Dataset
/// Expectation: The size is correct
TEST_F(MindDataTestPipeline, TestRenderedSST2GetDatasetSize) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2GetDatasetSize.";

  // Create a RenderedSST2 Dataset.
  std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
  std::shared_ptr<Dataset> ds = RenderedSST2(folder_path);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 12);
}

/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset with mixed getters
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestRenderedSST2Getters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2Getters.";

  // Create a RenderedSST2 Dataset.
  std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
  std::shared_ptr<Dataset> ds = RenderedSST2(folder_path);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 12);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  int64_t num_classes = ds->GetNumClasses();
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "uint32");
  EXPECT_EQ(shapes.size(), 2);
  EXPECT_EQ(num_classes, 2);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  EXPECT_EQ(ds->GetDatasetSize(), 12);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), 2);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 12);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), 2);
  EXPECT_EQ(ds->GetDatasetSize(), 12);
}

/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset failure with an invalid dataset path
/// Expectation: Throw correct error and message
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetFail.";

  // Create a RenderedSST2 Dataset with an empty dataset path.
  std::shared_ptr<Dataset> ds = RenderedSST2("", "train", false, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid RenderedSST2 input.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: RenderedSST2Dataset
/// Description: Test RenderedSST2Dataset with a null sampler
/// Expectation: Throw correct error and message
TEST_F(MindDataTestPipeline, TestRenderedSST2DatasetWithNullSamplerFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenderedSST2DatasetWithNullSamplerFail.";

  // Create a RenderedSST2 Dataset.
  std::string folder_path = datasets_root_path_ + "/testRenderedSST2Data/";
  std::shared_ptr<Dataset> ds = RenderedSST2(folder_path, "train", false, nullptr);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid RenderedSST2 input, sampler cannot be nullptr.
  EXPECT_EQ(iter, nullptr);
}
[Binary test images added: 12 files, 17 KiB to 27 KiB each]
@ -0,0 +1,222 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test RenderedSST2 dataset operators
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import log as logger

IMAGE_DATA_DIR = "../data/dataset/testRenderedSST2Data"
WRONG_DIR = "../data/dataset/notExist"


def test_rendered_sst2_basic():
    """
    Feature: RenderedSST2Dataset
    Description: Basic test of RenderedSST2Dataset
    Expectation: The data is processed successfully
    """
    logger.info("Test RenderedSST2Dataset Op")
    # case 1: test read all data
    all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False)
    all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False)

    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == 12

    # case 2: test decode
    all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True, shuffle=False)
    all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True, shuffle=False)

    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == 12

    # case 3: test num_samples
    all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=4)
    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 4

    # case 4: test repeat
    all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=4)
    all_data = all_data.repeat(2)
    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 8

    # case 5: test get_dataset_size, resize and batch
    all_data = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_samples=8)
    all_data = all_data.map(operations=[vision.Decode(), vision.Resize((256, 256))], input_columns=["image"],
                            num_parallel_workers=1)

    assert all_data.get_dataset_size() == 8
    assert all_data.get_batch_size() == 1
    # drop_remainder defaults to False
    all_data = all_data.batch(batch_size=3)
    assert all_data.get_batch_size() == 3
    assert all_data.get_dataset_size() == 3

    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3


def test_rendered_sst2_decode():
    """
    Feature: RenderedSST2Dataset
    Description: Validate RenderedSST2Dataset with decode
    Expectation: The data is processed successfully
    """
    logger.info("Validate RenderedSST2Dataset with decode")
    # define parameters
    repeat_count = 1

    data1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, decode=True)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    # each data is a dictionary
    for item in data1.create_dict_iterator(num_epochs=1):
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


def test_rendered_sst2_sequential_sampler():
    """
    Feature: RenderedSST2Dataset
    Description: Test RenderedSST2Dataset with SequentialSampler
    Expectation: The data is processed successfully
    """
    logger.info("Test RenderedSST2Dataset Op with SequentialSampler")
    num_samples = 4
    sampler = ds.SequentialSampler(num_samples=num_samples)
    all_data_1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=sampler)
    all_data_2 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_samples=num_samples)
    label_list_1, label_list_2 = [], []
    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
                            all_data_2.create_dict_iterator(num_epochs=1)):
        label_list_1.append(item1["label"].asnumpy())
        label_list_2.append(item2["label"].asnumpy())
        num_iter += 1
    np.testing.assert_array_equal(label_list_1, label_list_2)
    assert num_iter == num_samples


def test_rendered_sst2_random_sampler():
    """
    Feature: RenderedSST2Dataset
    Description: Test RenderedSST2Dataset with RandomSampler
    Expectation: The data is processed successfully
    """
    logger.info("Test RenderedSST2Dataset Op with RandomSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    sampler = ds.RandomSampler()
    data1 = ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    # each data is a dictionary
    for item in data1.create_dict_iterator(num_epochs=1):
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


def test_rendered_sst2_exception():
    """
    Feature: RenderedSST2Dataset
    Description: Test error cases for RenderedSST2Dataset
    Expectation: Throw correct error and message
    """
    logger.info("Test error cases for RenderedSST2Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, sampler=ds.SequentialSampler(1))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, sampler=ds.SequentialSampler(1), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=5, shard_id=-1)

    with pytest.raises(ValueError, match=error_msg_5):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=5, shard_id=5)

    with pytest.raises(ValueError, match=error_msg_5):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.RenderedSST2Dataset(IMAGE_DATA_DIR, num_shards=2, shard_id="0")

    error_msg_8 = "does not exist or is not a directory or permission denied!"
    with pytest.raises(ValueError, match=error_msg_8):
        all_data = ds.RenderedSST2Dataset(WRONG_DIR)
        for _ in all_data.create_dict_iterator(num_epochs=1):
            pass


if __name__ == '__main__':
    test_rendered_sst2_basic()
    test_rendered_sst2_decode()
    test_rendered_sst2_sequential_sampler()
    test_rendered_sst2_random_sampler()
    test_rendered_sst2_exception()