forked from mindspore-Ecosystem/mindspore
!43451 [feat][assistant][I5EWHX] add new data operator SUN397Dataset
Merge pull request !43451 from zhixinaa/SUN397
This commit is contained in:
commit
68359949ad
|
@ -0,0 +1,114 @@
|
|||
mindspore.dataset.SUN397Dataset
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.dataset.SUN397Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
|
||||
|
||||
读取和解析SUN397数据集的源文件构建数据集。
|
||||
|
||||
生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型是uint8。 `label` 列的数据类型是uint32。
|
||||
|
||||
参数:
|
||||
- **dataset_dir** (str) - 包含数据集文件的根目录路径。
|
||||
- **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值:None,读取全部样本图片。
|
||||
- **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值:None,使用mindspore.dataset.config中配置的线程数。
|
||||
- **shuffle** (bool, 可选) - 是否混洗数据集。默认值:None,下表中会展示不同参数配置的预期行为。
|
||||
- **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值:False,不解码。
|
||||
- **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值:None,下表中会展示不同配置的预期行为。
|
||||
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值:None。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值:None。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值:None,不使用缓存。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。
|
||||
- **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。
|
||||
- **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
|
||||
- **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
|
||||
- **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。
|
||||
- **ValueError** - `shard_id` 参数错误,小于0或者大于等于 `num_shards` 。
|
||||
|
||||
.. note:: 此数据集可以指定参数 `sampler` ,但参数 `sampler` 和参数 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数组合及预期的行为。
|
||||
|
||||
.. list-table:: 配置 `sampler` 和 `shuffle` 的不同组合得到的预期排序结果
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 参数 `sampler`
|
||||
- 参数 `shuffle`
|
||||
- 预期数据顺序
|
||||
* - None
|
||||
- None
|
||||
- 随机排列
|
||||
* - None
|
||||
- True
|
||||
- 随机排列
|
||||
* - None
|
||||
- False
|
||||
- 顺序排列
|
||||
* - `sampler` 实例
|
||||
- None
|
||||
- 由 `sampler` 行为定义的顺序
|
||||
* - `sampler` 实例
|
||||
- True
|
||||
- 不允许
|
||||
* - `sampler` 实例
|
||||
- False
|
||||
- 不允许
|
||||
|
||||
**关于SUN397数据集:**
|
||||
|
||||
SUN397是一个用于场景识别的数据集,包括397个类别,有108,754张图像。不同类别的图像数量不同,但每个类别至少有100张。
|
||||
图片为jpg、png或gif格式。
|
||||
|
||||
以下是原始SUN397数据集结构。
|
||||
可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── sun397_dataset_directory
|
||||
├── ClassName.txt
|
||||
├── README.txt
|
||||
├── a
|
||||
│ ├── abbey
|
||||
│ │ ├── sun_aaaulhwrhqgejnyt.jpg
|
||||
│ │ ├── sun_aacphuqehdodwawg.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── apartment_building
|
||||
│ │ └── outdoor
|
||||
│ │ ├── sun_aamyhslnsnomjzue.jpg
|
||||
│ │ ├── sun_abbjzfrsalhqivis.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── ...
|
||||
├── b
|
||||
│ ├── badlands
|
||||
│ │ ├── sun_aabtemlmesogqbbp.jpg
|
||||
│ │ ├── sun_afbsfeexggdhzshd.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── balcony
|
||||
│ │ ├── exterior
|
||||
│ │ │ ├── sun_aaxzaiuznwquburq.jpg
|
||||
│ │ │ ├── sun_baajuldidvlcyzhv.jpg
|
||||
│ │ │ ├── ...
|
||||
│ │ └── interior
|
||||
│ │ ├── sun_babkzjntjfarengi.jpg
|
||||
│ │ ├── sun_bagjvjynskmonnbv.jpg
|
||||
│ │ ├── ...
|
||||
│ └── ...
|
||||
├── ...
|
||||
|
||||
**引用:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
@inproceedings{xiao2010sun,
|
||||
title = {Sun database: Large-scale scene recognition from abbey to zoo},
|
||||
author = {Xiao, Jianxiong and Hays, James and Ehinger, Krista A and Oliva, Aude and Torralba, Antonio},
|
||||
booktitle = {2010 IEEE computer society conference on computer vision and pattern recognition},
|
||||
pages = {3485--3492},
|
||||
year = {2010},
|
||||
organization = {IEEE}
|
||||
}
|
||||
|
||||
|
||||
.. include:: mindspore.dataset.api_list_vision.rst
|
|
@ -124,6 +124,7 @@ mindspore.dataset
|
|||
mindspore.dataset.SBUDataset
|
||||
mindspore.dataset.SemeionDataset
|
||||
mindspore.dataset.STL10Dataset
|
||||
mindspore.dataset.SUN397Dataset
|
||||
mindspore.dataset.SVHNDataset
|
||||
mindspore.dataset.USPSDataset
|
||||
mindspore.dataset.VOCDataset
|
||||
|
|
|
@ -36,6 +36,7 @@ Vision
|
|||
mindspore.dataset.SBUDataset
|
||||
mindspore.dataset.SemeionDataset
|
||||
mindspore.dataset.STL10Dataset
|
||||
mindspore.dataset.SUN397Dataset
|
||||
mindspore.dataset.SVHNDataset
|
||||
mindspore.dataset.USPSDataset
|
||||
mindspore.dataset.VOCDataset
|
||||
|
|
|
@ -122,6 +122,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
|
@ -1904,6 +1905,28 @@ STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vect
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -60,6 +60,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
|
||||
|
@ -707,6 +708,16 @@ PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SUN397Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SUN397Node, DatasetNode, std::shared_ptr<SUN397Node>>(*m, "SUN397Node",
|
||||
"to create a SUN397Node")
|
||||
.def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
|
||||
auto sun397 = std::make_shared<SUN397Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(sun397->ValidateParams());
|
||||
return sun397;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
|
||||
"to create a TedliumNode")
|
||||
|
|
|
@ -50,6 +50,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
speech_commands_op.cc
|
||||
squad_op.cc
|
||||
stl10_op.cc
|
||||
sun397_op.cc
|
||||
tedlium_op.cc
|
||||
text_file_op.cc
|
||||
udpos_op.cc
|
||||
|
|
|
@ -0,0 +1,232 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/sun397_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
|
||||
#include "include/common/debug/common.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
constexpr char kCategoriesMeta[] = "ClassName.txt";
|
||||
|
||||
SUN397Op::SUN397Op(const std::string &file_dir, bool decode, int32_t num_workers, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
folder_path_(file_dir),
|
||||
decode_(decode),
|
||||
buf_cnt_(0),
|
||||
categorie2id_({}),
|
||||
image_path_label_pairs_({}),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
Status SUN397Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
auto file_path = image_path_label_pairs_[row_id].first;
|
||||
auto label_num = image_path_label_pairs_[row_id].second;
|
||||
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(label_num, &label));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(file_path, &image));
|
||||
|
||||
if (decode_) {
|
||||
Status rc = Decode(image, &image);
|
||||
if (rc.IsError()) {
|
||||
std::string err = "Invalid image, " + file_path + " decode failed, the image is broken or permission denied.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
trow->setPath({file_path, std::string("")});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void SUN397Op::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op.
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff.
|
||||
out << "\nNumber of rows: " << num_rows_ << "\nSUN397 directory: " << folder_path_
|
||||
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp.
|
||||
Status SUN397Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || image_path_label_pairs_.empty()) {
|
||||
if (image_path_label_pairs_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("No image found in dataset. Check if image was read successfully.");
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place,"
|
||||
"it must be empty before using GetClassIds.");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < image_path_label_pairs_.size(); ++i) {
|
||||
(*cls_ids)[image_path_label_pairs_[i].second].push_back(i);
|
||||
}
|
||||
for (auto &pair : (*cls_ids)) {
|
||||
pair.second.shrink_to_fit();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Op::GetFileContent(const std::string &info_file, std::string *ans) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ans);
|
||||
std::ifstream reader;
|
||||
reader.open(info_file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
!reader.fail(), "Invalid file, failed to open " + info_file + ": SUN397 file is damaged or permission denied.");
|
||||
reader.seekg(0, std::ios::end);
|
||||
std::size_t size = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
std::string buffer(size, 0);
|
||||
reader.read(&buffer[0], size);
|
||||
buffer[size] = '\0';
|
||||
reader.close();
|
||||
|
||||
// remove \n character in the buffer.
|
||||
std::string so(buffer);
|
||||
std::regex pattern("([\\s\\n]+)");
|
||||
std::string fmt = " ";
|
||||
std::string s = std::regex_replace(so, pattern, fmt);
|
||||
|
||||
// remove the head and tail whiteblanks of the s.
|
||||
s.erase(0, s.find_first_not_of(" "));
|
||||
s.erase(s.find_last_not_of(" ") + 1);
|
||||
// append one whiteblanks to the end of s.
|
||||
s += " ";
|
||||
*ans = s;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Op::LoadCategories(const std::string &category_meta_name) {
|
||||
categorie2id_.clear();
|
||||
std::string s;
|
||||
RETURN_IF_NOT_OK(GetFileContent(category_meta_name, &s));
|
||||
auto get_splited_str = [&s, &category_meta_name](std::size_t pos) {
|
||||
std::string item = s.substr(0, pos);
|
||||
// If pos+1 is equal to the string length, the function returns an empty string.
|
||||
s = s.substr(pos + 1);
|
||||
return item;
|
||||
};
|
||||
|
||||
std::string category;
|
||||
uint32_t label = 0;
|
||||
std::size_t pos = 0;
|
||||
while ((pos = s.find(" ")) != std::string::npos) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pos + 1 <= s.size(), "Invalid data, Reading SUN397 category file failed: " +
|
||||
category_meta_name + ", space characters not found.");
|
||||
category = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!category.empty(), "Invalid data, Reading SUN397 category file failed: " +
|
||||
category_meta_name + ", space characters not found.");
|
||||
categorie2id_.insert({category, label});
|
||||
label++;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Op::PrepareData() {
|
||||
auto real_folder_path = FileUtils::GetRealPath(folder_path_.c_str());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Invalid file path, " + folder_path_ + " does not exist.");
|
||||
|
||||
RETURN_IF_NOT_OK(LoadCategories((Path(real_folder_path.value()) / Path(kCategoriesMeta)).ToString()));
|
||||
image_path_label_pairs_.clear();
|
||||
for (auto c2i : categorie2id_) {
|
||||
std::string folder_name = c2i.first;
|
||||
uint32_t label = c2i.second;
|
||||
|
||||
Path folder(folder_path_ + folder_name);
|
||||
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
|
||||
if (!folder.Exists() || dirItr == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, " + folder_name + " does not exist or permission denied.");
|
||||
}
|
||||
std::set<std::string> imgs; // use this for ordering
|
||||
auto dirname_offset = folder.ToString().size();
|
||||
while (dirItr->HasNext()) {
|
||||
Path file = dirItr->Next();
|
||||
if (file.Extension() == ".jpg") {
|
||||
auto file_str = file.ToString();
|
||||
if (file_str.substr(dirname_offset + 1).find("sun_") == 0) {
|
||||
(void)imgs.insert(file_str);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "SUN397Dataset unsupported file found: " << file.ToString()
|
||||
<< ", extension: " << file.Extension() << ".";
|
||||
}
|
||||
}
|
||||
for (const std::string &img : imgs) {
|
||||
image_path_label_pairs_.push_back({img, label});
|
||||
}
|
||||
}
|
||||
num_rows_ = image_path_label_pairs_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_rows_ > 0,
|
||||
"Invalid data, no valid data matching the dataset API SUN397Dataset. Please check dataset API or file path: " +
|
||||
folder_path_ + ".");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Op::CountTotalRows(const std::string &dir, bool decode, int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
auto op =
|
||||
std::make_shared<SUN397Op>(dir, decode, num_workers, op_connect_size, std::move(schema), std::move(sampler));
|
||||
RETURN_IF_NOT_OK(op->PrepareData());
|
||||
|
||||
*count = op->image_path_label_pairs_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Op::ComputeColMap() {
|
||||
// set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,116 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Forward declares.
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
using SUN397LabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;
|
||||
|
||||
class SUN397Op : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] file_dir Dir directory of SUN397Dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] num_workers Num of workers reading images in parallel.
|
||||
/// \param[in] queue_size Connector queue size.
|
||||
/// \param[in] data_schema Schema of data.
|
||||
/// \param[in] sampler Sampler tells SUN397Op what to read.
|
||||
SUN397Op(const std::string &file_dir, bool decode, int32_t num_workers, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~SUN397Op() = default;
|
||||
|
||||
/// \brief Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
|
||||
/// \param[in] cls_ids Key label, val all ids for this class.
|
||||
/// \return The status code returned.
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \param[in] dir Path to the PhotoTour directory.
|
||||
/// \param[in] decode Decode jpg format images.
|
||||
/// \param[out] count Output arg that will hold the minimum of the actual dataset
|
||||
/// size and numSamples.
|
||||
/// \return The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, bool decode, int64_t *count);
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "SUN397Op"; }
|
||||
|
||||
private:
|
||||
/// \brief Load a tensor row according to a pair.
|
||||
/// \param[in] row_id Id for this tensor row.
|
||||
/// \param[out] row Image & label read into this tensor row.
|
||||
/// \return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row);
|
||||
|
||||
/// \brief The content in the given file path.
|
||||
/// \param[in] info_file Info file name.
|
||||
/// \param[out] ans Store the content of the info file.
|
||||
/// \return Status The status code returned
|
||||
Status GetFileContent(const std::string &info_file, std::string *ans);
|
||||
|
||||
/// \brief Load the meta information of categories.
|
||||
/// \param[in] category_meta_name Category file name.
|
||||
/// \return Status The status code returned.
|
||||
Status LoadCategories(const std::string &category_meta_name);
|
||||
|
||||
/// \brief Initialize SUN397Op related var, calls the function to walk all files.
|
||||
/// \return Status The status code returned.
|
||||
Status PrepareData() override;
|
||||
|
||||
/// \brief Private function for computing the assignment of the column name map.
|
||||
/// \return Status The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int64_t buf_cnt_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
std::string folder_path_; // directory of image folder
|
||||
const bool decode_;
|
||||
std::map<std::string, uint32_t> categorie2id_;
|
||||
std::vector<std::pair<std::string, uint32_t>> image_path_label_pairs_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_
|
|
@ -124,6 +124,7 @@ constexpr char kSogouNewsNode[] = "SogouNewsDataset";
|
|||
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
|
||||
constexpr char kSQuADNode[] = "SQuADDataset";
|
||||
constexpr char kSTL10Node[] = "STL10Dataset";
|
||||
constexpr char kSUN397Node[] = "SUN397Dataset";
|
||||
constexpr char kTedliumNode[] = "TedliumDataset";
|
||||
constexpr char kTextFileNode[] = "TextFileDataset";
|
||||
constexpr char kTFRecordNode[] = "TFRecordDataset";
|
||||
|
|
|
@ -51,6 +51,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
speech_commands_node.cc
|
||||
squad_node.cc
|
||||
stl10_node.cc
|
||||
sun397_node.cc
|
||||
tedlium_node.cc
|
||||
text_file_node.cc
|
||||
tf_record_node.cc
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
SUN397Node::SUN397Node(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> SUN397Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<SUN397Node>(dataset_dir_, decode_, sampler, cache_);
|
||||
(void)node->SetNumWorkers(num_workers_);
|
||||
(void)node->SetConnectorQueueSize(connector_que_size_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void SUN397Node::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status SUN397Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SUN397Node", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("SUN397Node", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node_ops);
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<SUN397Op>(dataset_dir_, decode_, num_workers_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node.
|
||||
Status SUN397Node::GetShardId(int32_t *shard_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(shard_id);
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size.
|
||||
Status SUN397Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dataset_size);
|
||||
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(SUN397Op::CountTotalRows(dataset_dir_, decode_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SUN397Node::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["connector_queue_size"] = connector_que_size_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["decode"] = decode_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sun397_op.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \class SUN397Node.
|
||||
/// \brief A Dataset derived class to represent SUN397 dataset.
|
||||
class SUN397Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] dataset_dir Dataset directory of SUN397Dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Tells SUN397Op what to read.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
SUN397Node(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~SUN397Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kSUN397Node; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id The shard ID within num_shards.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed
|
||||
/// up the process of getting dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size The size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const bool Decode() const { return decode_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
/// \param[in] sampler Specify sampler.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_
|
|
@ -5314,6 +5314,94 @@ inline std::shared_ptr<STL10Dataset> DATASET_API STL10(const std::string &datase
|
|||
return std::make_shared<STL10Dataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class SUN397Dataset.
|
||||
/// \brief A source dataset that reads and parses SUN397 dataset.
|
||||
class DATASET_API SUN397Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of SUN397Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of SUN397Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of SUN397Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of SUN397Dataset.
|
||||
~SUN397Dataset() override = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a SUN397Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading. Default: true.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples
|
||||
/// be used to randomly iterate the entire dataset. Default: RandomSampler().
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the current SUN397Dataset.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define dataset path and MindData object */
|
||||
/// std::string folder_path = "/path/to/sun397_dataset_directory";
|
||||
/// std::shared_ptr<Dataset> ds = SUN397(folder_path, false, std::make_shared<RandomSampler>(false, 5));
|
||||
///
|
||||
/// /* Create iterator to read dataset */
|
||||
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In SUN397 dataset dataset, each dictionary has keys "image" and "label" */
|
||||
/// auto image = row["image"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<SUN397Dataset> DATASET_API
|
||||
SUN397(const std::string &dataset_dir, bool decode = true,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a SUN397Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the current SUN397Dataset.
|
||||
inline std::shared_ptr<SUN397Dataset> DATASET_API SUN397(const std::string &dataset_dir, bool decode,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a SUN397Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the current SUN397Dataset.
|
||||
inline std::shared_ptr<SUN397Dataset> DATASET_API SUN397(const std::string &dataset_dir, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class TedliumDataset
|
||||
/// \brief A source dataset for reading and parsing tedlium dataset.
|
||||
class DATASET_API TedliumDataset : public Dataset {
|
||||
|
|
|
@ -69,6 +69,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class SemeionDataset;
|
||||
friend class SpeechCommandsDataset;
|
||||
friend class STL10Dataset;
|
||||
friend class SUN397Dataset;
|
||||
friend class TedliumDataset;
|
||||
friend class TextFileDataset;
|
||||
friend class TFRecordDataset;
|
||||
|
|
|
@ -67,6 +67,7 @@ __all__ = ["Caltech101Dataset", # Vision
|
|||
"SBUDataset", # Vision
|
||||
"SemeionDataset", # Vision
|
||||
"STL10Dataset", # Vision
|
||||
"SUN397Dataset", # Vision
|
||||
"SVHNDataset", # Vision
|
||||
"USPSDataset", # Vision
|
||||
"VOCDataset", # Vision
|
||||
|
|
|
@ -38,7 +38,7 @@ from .validators import check_caltech101_dataset, check_caltech256_dataset, chec
|
|||
check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
|
||||
check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
|
||||
check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
|
||||
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset
|
||||
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset, check_sun397_dataset
|
||||
|
||||
from ..core.validator_helpers import replace_none
|
||||
|
||||
|
@ -4437,6 +4437,150 @@ class STL10Dataset(MappableDataset, VisionBaseDataset):
|
|||
return cde.STL10Node(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class SUN397Dataset(MappableDataset, VisionBaseDataset):
|
||||
"""
|
||||
A source dataset that reads and parses SUN397 dataset.
|
||||
|
||||
The generated dataset has two columns: :py:obj:`[image, label]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`label` is of the uint32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
Default: None, all images.
|
||||
num_parallel_workers (int, optional): Number of workers to read the data.
|
||||
Default: None, number set in the mindspore.dataset.config.
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
|
||||
Default: None, expected order behavior shown in the table below.
|
||||
decode (bool, optional): Whether or not to decode the images after reading. Default: False.
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset. Default: None, expected order behavior shown in the table below.
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into. When this argument is specified, `num_samples` reflects
|
||||
the maximum sample number of per shard. Default: None.
|
||||
shard_id (int, optional): The shard ID within `num_shards` . This
|
||||
argument can only be specified when `num_shards` is also specified. Default: None.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
|
||||
`Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
|
||||
Default: None, which means no cache is used.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `dataset_dir` does not contain data files.
|
||||
RuntimeError: If `sampler` and `shuffle` are specified at the same time.
|
||||
RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
|
||||
RuntimeError: If `num_shards` is specified but `shard_id` is None.
|
||||
RuntimeError: If `shard_id` is specified but `num_shards` is None.
|
||||
ValueError: If `num_parallel_workers` exceeds the max thread numbers.
|
||||
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> sun397_dataset_dir = "/path/to/sun397_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) Read all samples (image files) in sun397_dataset_dir with 8 threads
|
||||
>>> dataset = ds.SUN397Dataset(dataset_dir=sun397_dataset_dir, num_parallel_workers=8)
|
||||
|
||||
About SUN397Dataset:
|
||||
|
||||
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of 397 categories with
|
||||
108,754 images.The number of images varies across categories, but there are at least 100 images per category.
|
||||
Images are in jpg, png, or gif format.
|
||||
|
||||
Here is the original SUN397 dataset structure.
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── sun397_dataset_directory
|
||||
├── ClassName.txt
|
||||
├── README.txt
|
||||
├── a
|
||||
│ ├── abbey
|
||||
│ │ ├── sun_aaaulhwrhqgejnyt.jpg
|
||||
│ │ ├── sun_aacphuqehdodwawg.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── apartment_building
|
||||
│ │ └── outdoor
|
||||
│ │ ├── sun_aamyhslnsnomjzue.jpg
|
||||
│ │ ├── sun_abbjzfrsalhqivis.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── ...
|
||||
├── b
|
||||
│ ├── badlands
|
||||
│ │ ├── sun_aabtemlmesogqbbp.jpg
|
||||
│ │ ├── sun_afbsfeexggdhzshd.jpg
|
||||
│ │ ├── ...
|
||||
│ ├── balcony
|
||||
│ │ ├── exterior
|
||||
│ │ │ ├── sun_aaxzaiuznwquburq.jpg
|
||||
│ │ │ ├── sun_baajuldidvlcyzhv.jpg
|
||||
│ │ │ ├── ...
|
||||
│ │ └── interior
|
||||
│ │ ├── sun_babkzjntjfarengi.jpg
|
||||
│ │ ├── sun_bagjvjynskmonnbv.jpg
|
||||
│ │ ├── ...
|
||||
│ └── ...
|
||||
├── ...
|
||||
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@inproceedings{xiao2010sun,
|
||||
title = {Sun database: Large-scale scene recognition from abbey to zoo},
|
||||
author = {Xiao, Jianxiong and Hays, James and Ehinger, Krista A and Oliva, Aude and Torralba, Antonio},
|
||||
booktitle = {2010 IEEE computer society conference on computer vision and pattern recognition},
|
||||
pages = {3485--3492},
|
||||
year = {2010},
|
||||
organization = {IEEE}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_sun397_dataset
|
||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False,
|
||||
sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.decode = replace_none(decode, False)
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.SUN397Node(self.dataset_dir, self.decode, self.sampler)
|
||||
|
||||
|
||||
class _SVHNDataset:
|
||||
"""
|
||||
Mainly for loading SVHN Dataset, and return two rows each time.
|
||||
|
|
|
@ -2865,6 +2865,31 @@ def check_stl10_dataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_sun397_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(SUN397Dataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle', 'decode']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_yahoo_answers_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original YahooAnswers Dataset."""
|
||||
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test basic usage of SUN397Dataset
|
||||
/// Expectation: The dataset is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestSUN397TrainStandardDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainStandardDataset.";
|
||||
|
||||
// Create a SUN397 Train Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 4);
|
||||
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test usage of SUN397Dataset with pipeline mode
|
||||
/// Expectation: The dataset is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestSUN397TrainDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainDatasetWithPipeline.";
|
||||
|
||||
// Create two SUN397 Train Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds1 = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
|
||||
std::shared_ptr<Dataset> ds2 = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds.
|
||||
int32_t repeat_num = 2;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds.
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds.
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 16);
|
||||
|
||||
// Manually terminate the pipeline.
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test iterator of SUN397Dataset with only the image column
|
||||
/// Expectation: The dataset is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestSUN397IteratorOneColumn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397IteratorOneColumn.";
|
||||
// Create a SUN397 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 2;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// Only select "image" column and drop others
|
||||
std::vector<std::string> columns = {"image"};
|
||||
std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
|
||||
std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::vector<mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
std::vector<int64_t> expect_image = {2, 256, 256, 3};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (auto &v : row) {
|
||||
MS_LOG(INFO) << "image shape:" << v.Shape();
|
||||
EXPECT_EQ(expect_image, v.Shape());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test iterator of SUN397Dataset with wrong column
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestSUN397IteratorWrongColumn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397IteratorWrongColumn.";
|
||||
// Create a SUN397 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Pass wrong column name
|
||||
std::vector<std::string> columns = {"digital"};
|
||||
std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
|
||||
std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test usage of GetDatasetSize of SUN397TrainDataset
|
||||
/// Expectation: Get the correct size
|
||||
TEST_F(MindDataTestPipeline, TestGetSUN397TrainDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSUN397TrainDatasetSize.";
|
||||
|
||||
// Create a SUN397 Train Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test SUN397Dataset Getters method
|
||||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestPipeline, TestSUN397TrainDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainDatasetGetters.";
|
||||
|
||||
// Create a SUN397 Train Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image", "label"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "uint32");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<256,256,3>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test SUN397Dataset with invalid folder path input
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestSUN397DatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397DatasetFail.";
|
||||
|
||||
// Create a SUN397 Dataset.
|
||||
std::shared_ptr<Dataset> ds = SUN397("", true, std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid SUN397 input.
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: SUN397Dataset
|
||||
/// Description: Test SUN397Dataset with null sampler
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestSUN397DatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397DatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a SUN397 Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
|
||||
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid SUN397 input, sampler cannot be nullptr.
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
/a/abbey
|
||||
/b/badlands
|
Binary file not shown.
After Width: | Height: | Size: 6.1 KiB |
Binary file not shown.
After Width: | Height: | Size: 4.9 KiB |
Binary file not shown.
After Width: | Height: | Size: 6.1 KiB |
Binary file not shown.
After Width: | Height: | Size: 4.3 KiB |
|
@ -0,0 +1,498 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test SUN397 dataset operators
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testSUN397Data"
|
||||
|
||||
|
||||
def test_sun397_basic():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case basic")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_sun397_num_samples():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with num_samples=10 and num_parallel_workers=2
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case num_samples")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
|
||||
def test_sun397_num_shards():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with num_shards=2 and shard_id=1
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case numShards")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=1)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_sun397_shard_id():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with num_shards=2 and shard_id=0
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case withShardID")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=0)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_sun397_no_shuffle():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with shuffle=False
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case noShuffle")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_sun397_extra_shuffle():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with shuffle=True
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case extra_shuffle")
|
||||
# define parameters
|
||||
repeat_count = 2
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, shuffle=True)
|
||||
data1 = data1.shuffle(buffer_size=5)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_sun397_decode():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Test basic usage of SUN397
|
||||
Expectation: The dataset is as expected
|
||||
"""
|
||||
logger.info("Test Case decode")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, decode=True)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_sequential_sampler():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with sampler=ds.SequentialSampler()
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case SequentialSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
# apply dataset operations
|
||||
sampler = ds.SequentialSampler(num_samples=10)
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
result = []
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
result.append(item["label"])
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 4
|
||||
logger.info("Result: {}".format(result))
|
||||
|
||||
|
||||
def test_random_sampler():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with sampler=ds.RandomSampler()
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case RandomSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.RandomSampler()
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_distributed_sampler():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with sampler=ds.DistributedSampler()
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case DistributedSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.DistributedSampler(2, 1)
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_pk_sampler():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with sampler=ds.PKSampler()
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case PKSampler")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
sampler = ds.PKSampler(1)
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_chained_sampler():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with Random and Sequential, with repeat
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
|
||||
|
||||
# Create chained sampler, random and sequential
|
||||
sampler = ds.RandomSampler()
|
||||
child_sampler = ds.SequentialSampler()
|
||||
sampler.add_child(child_sampler)
|
||||
# Create SUN397Dataset with sampler
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
|
||||
|
||||
data1 = data1.repeat(count=3)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 12
|
||||
|
||||
# Verify number of iterations
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 12
|
||||
|
||||
|
||||
def test_sun397_zip():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with zip
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test Case zip")
|
||||
# define parameters
|
||||
repeat_count = 2
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.SUN397Dataset(DATA_DIR, num_samples=10)
|
||||
data2 = ds.SUN397Dataset(DATA_DIR, num_samples=10)
|
||||
|
||||
data1 = data1.repeat(repeat_count)
|
||||
# rename dataset2 for no conflict
|
||||
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
|
||||
data3 = ds.zip((data1, data2))
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_sun397_exception():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with exception
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test sun397 exception")
|
||||
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.SUN397Dataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.SUN397Dataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.SUN397Dataset(DATA_DIR, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.SUN397Dataset(DATA_DIR, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SUN397Dataset(DATA_DIR, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SUN397Dataset(DATA_DIR, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id="0")
|
||||
|
||||
|
||||
def test_sun397_exception_map():
|
||||
"""
|
||||
Feature: Test SUN397 Dataset
|
||||
Description: Read data from all file with map operation exception
|
||||
Expectation: The data is processed successfully
|
||||
"""
|
||||
logger.info("Test sun397 exception map")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
def exception_func2(image, label):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
try:
|
||||
data = ds.SUN397Dataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.SUN397Dataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func2,
|
||||
input_columns=["image", "label"],
|
||||
output_columns=["image", "label", "label1"],
|
||||
num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.SUN397Dataset(DATA_DIR)
|
||||
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
|
||||
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sun397_basic()
|
||||
test_sun397_num_samples()
|
||||
test_sequential_sampler()
|
||||
test_random_sampler()
|
||||
test_distributed_sampler()
|
||||
test_pk_sampler()
|
||||
test_sun397_num_shards()
|
||||
test_sun397_shard_id()
|
||||
test_sun397_no_shuffle()
|
||||
test_sun397_extra_shuffle()
|
||||
test_sun397_decode()
|
||||
test_sun397_zip()
|
||||
test_sun397_exception()
|
||||
test_sun397_exception_map()
|
Loading…
Reference in New Issue