!43451 [feat][assistant][I5EWHX] add new data operator SUN397Dataset

Merge pull request !43451 from zhixinaa/SUN397
i-robot 2022-11-29 03:33:30 +00:00 committed by Gitee
commit 68359949ad
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
24 changed files with 1749 additions and 1 deletion


@ -0,0 +1,114 @@
mindspore.dataset.SUN397Dataset
===============================
.. py:class:: mindspore.dataset.SUN397Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
A source dataset that reads and parses the SUN397 dataset.
The generated dataset has two columns: `[image, label]`. The data type of column `image` is uint8, and the data type of column `label` is uint32.
Parameters:
- **dataset_dir** (str) - Path to the root directory that contains the dataset.
- **num_samples** (int, optional) - Number of samples to read from the dataset; can be smaller than the total number of samples. Default: None, read all images.
- **num_parallel_workers** (int, optional) - Number of worker threads used to read the data. Default: None, use the number of threads configured in mindspore.dataset.config.
- **shuffle** (bool, optional) - Whether to shuffle the dataset. Default: None. The expected behavior for each configuration is shown in the table below.
- **decode** (bool, optional) - Whether to decode the images after reading. Default: False, do not decode.
- **sampler** (Sampler, optional) - Sampler used to choose samples from the dataset. Default: None. The expected behavior for each configuration is shown in the table below.
- **num_shards** (int, optional) - Number of shards that the dataset will be divided into for distributed training. Default: None. When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
- **shard_id** (int, optional) - The shard ID used in distributed training. Default: None. This argument can only be specified when `num_shards` is also specified.
- **cache** (DatasetCache, optional) - Single-node data caching service used to speed up dataset processing. For details, see `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ . Default: None, no cache is used.
Raises:
- **RuntimeError** - If `dataset_dir` does not contain data files.
- **RuntimeError** - If `sampler` and `shuffle` are specified at the same time.
- **RuntimeError** - If `sampler` and `num_shards`, or `sampler` and `shard_id`, are specified at the same time.
- **RuntimeError** - If `num_shards` is specified but `shard_id` is None.
- **RuntimeError** - If `shard_id` is specified but `num_shards` is None.
- **ValueError** - If `num_parallel_workers` exceeds the maximum number of threads on the system.
- **ValueError** - If `shard_id` is invalid (less than 0 or greater than or equal to `num_shards`).
.. note:: This dataset can take in a `sampler` , but `sampler` and `shuffle` are mutually exclusive. The table below shows the valid input combinations and their expected behavior; a usage sketch follows the table.
.. list-table:: Expected order behavior of different `sampler` and `shuffle` combinations
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - `sampler` instance
- None
- order defined by sampler
* - `sampler` instance
- True
- not allowed
* - `sampler` instance
- False
- not allowed
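A minimal usage sketch of the combinations above (the directory path is a placeholder):

.. code-block:: python

    import mindspore.dataset as ds

    sun397_dir = "/path/to/sun397_dataset_directory"
    # shuffle=False reads samples in sequential order
    dataset = ds.SUN397Dataset(sun397_dir, shuffle=False)
    # passing a sampler replaces shuffle; the two must not be combined
    dataset = ds.SUN397Dataset(sun397_dir, sampler=ds.RandomSampler(num_samples=5))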
**About the SUN397 dataset:**
SUN397, the Scene UNderstanding (SUN) dataset, is a scene recognition dataset consisting of 397 categories with 108,754 images. The number of images varies across categories, but there are at least 100 images per category.
Images are in jpg, png, or gif format.
The original SUN397 dataset structure is shown below.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
└── sun397_dataset_directory
├── ClassName.txt
├── README.txt
├── a
│ ├── abbey
│ │ ├── sun_aaaulhwrhqgejnyt.jpg
│ │ ├── sun_aacphuqehdodwawg.jpg
│ │ ├── ...
│ ├── apartment_building
│ │ └── outdoor
│ │ ├── sun_aamyhslnsnomjzue.jpg
│ │ ├── sun_abbjzfrsalhqivis.jpg
│ │ ├── ...
│ ├── ...
├── b
│ ├── badlands
│ │ ├── sun_aabtemlmesogqbbp.jpg
│ │ ├── sun_afbsfeexggdhzshd.jpg
│ │ ├── ...
│ ├── balcony
│ │ ├── exterior
│ │ │ ├── sun_aaxzaiuznwquburq.jpg
│ │ │ ├── sun_baajuldidvlcyzhv.jpg
│ │ │ ├── ...
│ │ └── interior
│ │ ├── sun_babkzjntjfarengi.jpg
│ │ ├── sun_bagjvjynskmonnbv.jpg
│ │ ├── ...
│ └── ...
├── ...
**Citation:**
.. code-block::
@inproceedings{xiao2010sun,
title = {Sun database: Large-scale scene recognition from abbey to zoo},
author = {Xiao, Jianxiong and Hays, James and Ehinger, Krista A and Oliva, Aude and Torralba, Antonio},
booktitle = {2010 IEEE computer society conference on computer vision and pattern recognition},
pages = {3485--3492},
year = {2010},
organization = {IEEE}
}
.. include:: mindspore.dataset.api_list_vision.rst


@ -124,6 +124,7 @@ mindspore.dataset
mindspore.dataset.SBUDataset
mindspore.dataset.SemeionDataset
mindspore.dataset.STL10Dataset
mindspore.dataset.SUN397Dataset
mindspore.dataset.SVHNDataset
mindspore.dataset.USPSDataset
mindspore.dataset.VOCDataset


@ -36,6 +36,7 @@ Vision
mindspore.dataset.SBUDataset
mindspore.dataset.SemeionDataset
mindspore.dataset.STL10Dataset
mindspore.dataset.SUN397Dataset
mindspore.dataset.SVHNDataset
mindspore.dataset.USPSDataset
mindspore.dataset.VOCDataset


@ -122,6 +122,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
@ -1904,6 +1905,28 @@ STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vect
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode,
const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {


@ -60,6 +60,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
@ -707,6 +708,16 @@ PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(SUN397Node, 2, ([](const py::module *m) {
(void)py::class_<SUN397Node, DatasetNode, std::shared_ptr<SUN397Node>>(*m, "SUN397Node",
"to create a SUN397Node")
.def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
auto sun397 = std::make_shared<SUN397Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(sun397->ValidateParams());
return sun397;
}));
}));
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
"to create a TedliumNode")


@ -50,6 +50,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
speech_commands_op.cc
squad_op.cc
stl10_op.cc
sun397_op.cc
tedlium_op.cc
text_file_op.cc
udpos_op.cc


@ -0,0 +1,232 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/sun397_op.h"
#include <fstream>
#include <iomanip>
#include <iostream>
#include <regex>
#include <set>
#include "include/common/debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
constexpr char kCategoriesMeta[] = "ClassName.txt";
SUN397Op::SUN397Op(const std::string &file_dir, bool decode, int32_t num_workers, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
folder_path_(file_dir),
decode_(decode),
buf_cnt_(0),
categorie2id_({}),
image_path_label_pairs_({}),
data_schema_(std::move(data_schema)) {}
Status SUN397Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
RETURN_UNEXPECTED_IF_NULL(trow);
auto file_path = image_path_label_pairs_[row_id].first;
auto label_num = image_path_label_pairs_[row_id].second;
std::shared_ptr<Tensor> image, label;
RETURN_IF_NOT_OK(Tensor::CreateScalar(label_num, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(file_path, &image));
if (decode_) {
Status rc = Decode(image, &image);
if (rc.IsError()) {
std::string err = "Invalid image, " + file_path + " decode failed, the image is broken or permission denied.";
RETURN_STATUS_UNEXPECTED(err);
}
}
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
trow->setPath({file_path, std::string("")});
return Status::OK();
}
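For intuition, a rough Python sketch of the row this function produces; when decode_ is false the image column is simply the raw file bytes (numpy and cv2 below are illustrative stand-ins, not what the C++ code uses):

    import numpy as np

    def load_row(file_path, label_num, decode=False):
        image = np.fromfile(file_path, dtype=np.uint8)  # raw bytes as a uint8 tensor
        if decode:
            import cv2  # stand-in for the C++ Decode() call
            image = cv2.imdecode(image, cv2.IMREAD_COLOR)
        label = np.uint32(label_num)  # scalar label, uint32
        return image, label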
void SUN397Op::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nNumber of rows: " << num_rows_ << "\nSUN397 directory: " << folder_path_
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
}
}
// Derived from RandomAccessOp.
Status SUN397Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_path_label_pairs_.empty()) {
if (image_path_label_pairs_.empty()) {
RETURN_STATUS_UNEXPECTED("No image found in dataset. Check if image was read successfully.");
} else {
RETURN_STATUS_UNEXPECTED(
"[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place,"
"it must be empty before using GetClassIds.");
}
}
for (size_t i = 0; i < image_path_label_pairs_.size(); ++i) {
(*cls_ids)[image_path_label_pairs_[i].second].push_back(i);
}
for (auto &pair : (*cls_ids)) {
pair.second.shrink_to_fit();
}
return Status::OK();
}
Status SUN397Op::GetFileContent(const std::string &info_file, std::string *ans) {
RETURN_UNEXPECTED_IF_NULL(ans);
std::ifstream reader;
reader.open(info_file);
CHECK_FAIL_RETURN_UNEXPECTED(
!reader.fail(), "Invalid file, failed to open " + info_file + ": SUN397 file is damaged or permission denied.");
reader.seekg(0, std::ios::end);
std::size_t size = reader.tellg();
reader.seekg(0, std::ios::beg);
std::string buffer(size, 0);
reader.read(&buffer[0], size);
buffer[size] = '\0';
reader.close();
// remove \n character in the buffer.
std::string so(buffer);
std::regex pattern("([\\s\\n]+)");
std::string fmt = " ";
std::string s = std::regex_replace(so, pattern, fmt);
// remove the head and tail whiteblanks of the s.
s.erase(0, s.find_first_not_of(" "));
s.erase(s.find_last_not_of(" ") + 1);
// append one whiteblanks to the end of s.
s += " ";
*ans = s;
return Status::OK();
}
Status SUN397Op::LoadCategories(const std::string &category_meta_name) {
categorie2id_.clear();
std::string s;
RETURN_IF_NOT_OK(GetFileContent(category_meta_name, &s));
auto get_splited_str = [&s, &category_meta_name](std::size_t pos) {
std::string item = s.substr(0, pos);
// If pos+1 is equal to the string length, the function returns an empty string.
s = s.substr(pos + 1);
return item;
};
std::string category;
uint32_t label = 0;
std::size_t pos = 0;
while ((pos = s.find(" ")) != std::string::npos) {
CHECK_FAIL_RETURN_UNEXPECTED(pos + 1 <= s.size(), "Invalid data, Reading SUN397 category file failed: " +
category_meta_name + ", space characters not found.");
category = get_splited_str(pos);
CHECK_FAIL_RETURN_UNEXPECTED(!category.empty(), "Invalid data, Reading SUN397 category file failed: " +
category_meta_name + ", space characters not found.");
categorie2id_.insert({category, label});
label++;
}
return Status::OK();
}
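A rough Python equivalent of GetFileContent() plus LoadCategories(), assuming ClassName.txt lists one category path per line (e.g. "/a/abbey") and labels are assigned in file order:

    def load_categories(meta_path):
        # collapse all whitespace/newlines and split, mirroring the regex replace above
        with open(meta_path, "r") as f:
            tokens = f.read().split()
        return {category: label for label, category in enumerate(tokens)}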
Status SUN397Op::PrepareData() {
auto real_folder_path = FileUtils::GetRealPath(folder_path_.c_str());
CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Invalid file path, " + folder_path_ + " does not exist.");
RETURN_IF_NOT_OK(LoadCategories((Path(real_folder_path.value()) / Path(kCategoriesMeta)).ToString()));
image_path_label_pairs_.clear();
for (auto c2i : categorie2id_) {
std::string folder_name = c2i.first;
uint32_t label = c2i.second;
Path folder(folder_path_ + folder_name);
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
if (!folder.Exists() || dirItr == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid path, " + folder_name + " does not exist or permission denied.");
}
std::set<std::string> imgs; // use this for ordering
auto dirname_offset = folder.ToString().size();
while (dirItr->HasNext()) {
Path file = dirItr->Next();
if (file.Extension() == ".jpg") {
auto file_str = file.ToString();
if (file_str.substr(dirname_offset + 1).find("sun_") == 0) {
(void)imgs.insert(file_str);
}
} else {
MS_LOG(WARNING) << "SUN397Dataset unsupported file found: " << file.ToString()
<< ", extension: " << file.Extension() << ".";
}
}
for (const std::string &img : imgs) {
image_path_label_pairs_.push_back({img, label});
}
}
num_rows_ = image_path_label_pairs_.size();
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0,
"Invalid data, no valid data matching the dataset API SUN397Dataset. Please check dataset API or file path: " +
folder_path_ + ".");
return Status::OK();
}
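The directory walk above, sketched in Python under the same assumptions: each ClassName.txt entry names a folder that directly contains sun_*.jpg files, and files are collected in sorted order per category:

    import os

    def prepare_data(folder_path, category2id):
        pairs = []
        for category, label in category2id.items():
            class_dir = folder_path + category  # category starts with '/', e.g. '/a/abbey'
            names = sorted(n for n in os.listdir(class_dir)
                           if n.startswith("sun_") and n.endswith(".jpg"))
            pairs += [(os.path.join(class_dir, n), label) for n in names]
        return pairs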
Status SUN397Op::CountTotalRows(const std::string &dir, bool decode, int64_t *count) {
RETURN_UNEXPECTED_IF_NULL(count);
*count = 0;
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connect_size = cfg->op_connector_size();
auto op =
std::make_shared<SUN397Op>(dir, decode, num_workers, op_connect_size, std::move(schema), std::move(sampler));
RETURN_IF_NOT_OK(op->PrepareData());
*count = op->image_path_label_pairs_.size();
return Status::OK();
}
Status SUN397Op::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore


@ -0,0 +1,116 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief Forward declares.
template <typename T>
class Queue;
using SUN397LabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;
class SUN397Op : public MappableLeafOp {
public:
/// \brief Constructor.
/// \param[in] file_dir Directory of the SUN397 dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] num_workers Num of workers reading images in parallel.
/// \param[in] queue_size Connector queue size.
/// \param[in] data_schema Schema of data.
/// \param[in] sampler Sampler tells SUN397Op what to read.
SUN397Op(const std::string &file_dir, bool decode, int32_t num_workers, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~SUN397Op() = default;
/// \brief Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
/// \param[in] cls_ids Key label, val all ids for this class.
/// \return The status code returned.
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
/// \brief A print method typically used for debugging.
/// \param[out] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \param[in] dir Path to the SUN397 directory.
/// \param[in] decode Decode jpg format images.
/// \param[out] count Output arg that will hold the minimum of the actual dataset
/// size and numSamples.
/// \return The status code returned.
static Status CountTotalRows(const std::string &dir, bool decode, int64_t *count);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "SUN397Op"; }
private:
/// \brief Load a tensor row according to a pair.
/// \param[in] row_id Id for this tensor row.
/// \param[out] row Image & label read into this tensor row.
/// \return Status The status code returned
Status LoadTensorRow(row_id_type row_id, TensorRow *row);
/// \brief Get the content of the given info file.
/// \param[in] info_file Info file name.
/// \param[out] ans Store the content of the info file.
/// \return Status The status code returned
Status GetFileContent(const std::string &info_file, std::string *ans);
/// \brief Load the meta information of categories.
/// \param[in] category_meta_name Category file name.
/// \return Status The status code returned.
Status LoadCategories(const std::string &category_meta_name);
/// \brief Initialize SUN397Op related var, calls the function to walk all files.
/// \return Status The status code returned.
Status PrepareData() override;
/// \brief Private function for computing the assignment of the column name map.
/// \return Status The status code returned.
Status ComputeColMap() override;
int64_t buf_cnt_;
std::unique_ptr<DataSchema> data_schema_;
std::string folder_path_; // directory of image folder
const bool decode_;
std::map<std::string, uint32_t> categorie2id_;
std::vector<std::pair<std::string, uint32_t>> image_path_label_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SUN397_OP_H_


@ -124,6 +124,7 @@ constexpr char kSogouNewsNode[] = "SogouNewsDataset";
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
constexpr char kSQuADNode[] = "SQuADDataset";
constexpr char kSTL10Node[] = "STL10Dataset";
constexpr char kSUN397Node[] = "SUN397Dataset";
constexpr char kTedliumNode[] = "TedliumDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";


@ -51,6 +51,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
speech_commands_node.cc
squad_node.cc
stl10_node.cc
sun397_node.cc
tedlium_node.cc
text_file_node.cc
tf_record_node.cc


@ -0,0 +1,115 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
SUN397Node::SUN397Node(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}
std::shared_ptr<DatasetNode> SUN397Node::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<SUN397Node>(dataset_dir_, decode_, sampler, cache_);
(void)node->SetNumWorkers(num_workers_);
(void)node->SetConnectorQueueSize(connector_que_size_);
return node;
}
void SUN397Node::Print(std::ostream &out) const { out << Name(); }
Status SUN397Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SUN397Node", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("SUN397Node", sampler_));
return Status::OK();
}
Status SUN397Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
RETURN_UNEXPECTED_IF_NULL(node_ops);
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<SUN397Op>(dataset_dir_, decode_, num_workers_, connector_que_size_, std::move(schema),
std::move(sampler_rt));
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node.
Status SUN397Node::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size.
Status SUN397Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
RETURN_UNEXPECTED_IF_NULL(dataset_size);
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(SUN397Op::CountTotalRows(dataset_dir_, decode_, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status SUN397Node::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["connector_queue_size"] = connector_que_size_;
args["dataset_dir"] = dataset_dir_;
args["decode"] = decode_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
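For reference, the serialized node sketched as a Python dict; every value below is a placeholder, and the nested sampler entry simply holds whatever sampler_->to_json() emits:

    sun397_node_json = {
        "sampler": {"sampler_name": "RandomSampler"},  # placeholder sampler serialization
        "num_parallel_workers": 8,
        "connector_queue_size": 16,
        "dataset_dir": "/path/to/sun397_dataset_directory",
        "decode": False,
    }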
} // namespace dataset
} // namespace mindspore


@ -0,0 +1,103 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/sun397_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class SUN397Node.
/// \brief A Dataset derived class to represent SUN397 dataset.
class SUN397Node : public MappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] dataset_dir Dataset directory of SUN397Dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Tells SUN397Op what to read.
/// \param[in] cache Tensor cache to use.
SUN397Node(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~SUN397Node() override = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kSUN397Node; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[in] shard_id The shard ID within num_shards.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed
/// up the process of getting dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const bool Decode() const { return decode_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Specify sampler.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SUN397_NODE_H_


@ -5314,6 +5314,94 @@ inline std::shared_ptr<STL10Dataset> DATASET_API STL10(const std::string &datase
return std::make_shared<STL10Dataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class SUN397Dataset.
/// \brief A source dataset that reads and parses SUN397 dataset.
class DATASET_API SUN397Dataset : public Dataset {
public:
/// \brief Constructor of SUN397Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
/// \param[in] cache Tensor cache to use.
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of SUN397Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of SUN397Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of SUN397Dataset.
~SUN397Dataset() override = default;
};
/// \brief Function to create a SUN397Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading. Default: true.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset. Default: std::make_shared<RandomSampler>().
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the current SUN397Dataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string folder_path = "/path/to/sun397_dataset_directory";
/// std::shared_ptr<Dataset> ds = SUN397(folder_path, false, std::make_shared<RandomSampler>(false, 5));
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In SUN397 dataset, each dictionary has keys "image" and "label" */
/// auto image = row["image"];
/// \endcode
inline std::shared_ptr<SUN397Dataset> DATASET_API
SUN397(const std::string &dataset_dir, bool decode = true,
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \brief Function to create a SUN397Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the current SUN397Dataset.
inline std::shared_ptr<SUN397Dataset> DATASET_API SUN397(const std::string &dataset_dir, bool decode,
const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \brief Function to create a SUN397Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the current SUN397Dataset.
inline std::shared_ptr<SUN397Dataset> DATASET_API SUN397(const std::string &dataset_dir, bool decode,
const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SUN397Dataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \class TedliumDataset
/// \brief A source dataset for reading and parsing tedlium dataset.
class DATASET_API TedliumDataset : public Dataset {


@ -69,6 +69,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
friend class SemeionDataset;
friend class SpeechCommandsDataset;
friend class STL10Dataset;
friend class SUN397Dataset;
friend class TedliumDataset;
friend class TextFileDataset;
friend class TFRecordDataset;


@ -67,6 +67,7 @@ __all__ = ["Caltech101Dataset", # Vision
"SBUDataset", # Vision
"SemeionDataset", # Vision
"STL10Dataset", # Vision
"SUN397Dataset", # Vision
"SVHNDataset", # Vision
"USPSDataset", # Vision
"VOCDataset", # Vision


@ -38,7 +38,7 @@ from .validators import check_caltech101_dataset, check_caltech256_dataset, chec
check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset, check_sun397_dataset
from ..core.validator_helpers import replace_none
@ -4437,6 +4437,150 @@ class STL10Dataset(MappableDataset, VisionBaseDataset):
return cde.STL10Node(self.dataset_dir, self.usage, self.sampler)
class SUN397Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads and parses SUN397 dataset.
The generated dataset has two columns: :py:obj:`[image, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`label` is of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
num_samples (int, optional): The number of images to be included in the dataset.
Default: None, all images.
num_parallel_workers (int, optional): Number of workers to read the data.
Default: None, number set in the mindspore.dataset.config.
shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
Default: None, expected order behavior shown in the table below.
decode (bool, optional): Whether or not to decode the images after reading. Default: False.
sampler (Sampler, optional): Object used to choose samples from the
dataset. Default: None, expected order behavior shown in the table below.
num_shards (int, optional): Number of shards that the dataset will be divided
into. When this argument is specified, `num_samples` reflects
the maximum number of samples per shard. Default: None.
shard_id (int, optional): The shard ID within `num_shards` . This
argument can only be specified when `num_shards` is also specified. Default: None.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
`Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
Default: None, which means no cache is used.
Raises:
RuntimeError: If `dataset_dir` does not contain data files.
RuntimeError: If `sampler` and `shuffle` are specified at the same time.
RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
RuntimeError: If `num_shards` is specified but `shard_id` is None.
RuntimeError: If `shard_id` is specified but `num_shards` is None.
ValueError: If `num_parallel_workers` exceeds the max thread numbers.
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
Note:
- This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> sun397_dataset_dir = "/path/to/sun397_dataset_directory"
>>>
>>> # 1) Read all samples (image files) in sun397_dataset_dir with 8 threads
>>> dataset = ds.SUN397Dataset(dataset_dir=sun397_dataset_dir, num_parallel_workers=8)
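A further sketch, this time with sharding and in-pipeline decoding (the two-way shard split is illustrative only):
>>> # 2) Read shard 0 of 2 and decode images while reading
>>> dataset = ds.SUN397Dataset(dataset_dir=sun397_dataset_dir, num_shards=2, shard_id=0, decode=True)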
About SUN397Dataset:
SUN397, or Scene UNderstanding (SUN), is a dataset for scene recognition consisting of 397 categories with
108,754 images. The number of images varies across categories, but there are at least 100 images per category.
Images are in jpg, png, or gif format.
Here is the original SUN397 dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
└── sun397_dataset_directory
├── ClassName.txt
├── README.txt
├── a
│ ├── abbey
│ │ ├── sun_aaaulhwrhqgejnyt.jpg
│ │ ├── sun_aacphuqehdodwawg.jpg
│ │ ├── ...
│ ├── apartment_building
│ │ └── outdoor
│ │ ├── sun_aamyhslnsnomjzue.jpg
│ │ ├── sun_abbjzfrsalhqivis.jpg
│ │ ├── ...
│ ├── ...
├── b
│ ├── badlands
│ │ ├── sun_aabtemlmesogqbbp.jpg
│ │ ├── sun_afbsfeexggdhzshd.jpg
│ │ ├── ...
│ ├── balcony
│ │ ├── exterior
│ │ │ ├── sun_aaxzaiuznwquburq.jpg
│ │ │ ├── sun_baajuldidvlcyzhv.jpg
│ │ │ ├── ...
│ │ └── interior
│ │ ├── sun_babkzjntjfarengi.jpg
│ │ ├── sun_bagjvjynskmonnbv.jpg
│ │ ├── ...
│ └── ...
├── ...
Citation:
.. code-block::
@inproceedings{xiao2010sun,
title = {Sun database: Large-scale scene recognition from abbey to zoo},
author = {Xiao, Jianxiong and Hays, James and Ehinger, Krista A and Oliva, Aude and Torralba, Antonio},
booktitle = {2010 IEEE computer society conference on computer vision and pattern recognition},
pages = {3485--3492},
year = {2010},
organization = {IEEE}
}
"""
@check_sun397_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False,
sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.decode = replace_none(decode, False)
def parse(self, children=None):
return cde.SUN397Node(self.dataset_dir, self.decode, self.sampler)
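For context, a hedged sketch of a typical downstream pipeline built on this operator; the transform list, resize target, and batch size are arbitrary placeholders:

    import mindspore.dataset as ds
    import mindspore.dataset.vision as vision

    dataset = ds.SUN397Dataset("/path/to/sun397_dataset_directory", decode=False)
    # decode raw bytes inside the pipeline, then resize and batch
    dataset = dataset.map(operations=[vision.Decode(), vision.Resize((224, 224))],
                          input_columns=["image"])
    dataset = dataset.batch(32, drop_remainder=True)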
class _SVHNDataset:
"""
Mainly for loading SVHN Dataset, and return two rows each time.


@ -2865,6 +2865,31 @@ def check_stl10_dataset(method):
return new_method
def check_sun397_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(SUN397Dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_yahoo_answers_dataset(method):
"""A wrapper that wraps a parameter checker around the original YahooAnswers Dataset."""


@ -0,0 +1,270 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: SUN397Dataset
/// Description: Test basic usage of SUN397Dataset
/// Expectation: The dataset is processed successfully
TEST_F(MindDataTestPipeline, TestSUN397TrainStandardDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainStandardDataset.";
// Create a SUN397 Train Dataset.
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 4);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: SUN397Dataset
/// Description: Test usage of SUN397Dataset with pipeline mode
/// Expectation: The dataset is processed successfully
TEST_F(MindDataTestPipeline, TestSUN397TrainDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainDatasetWithPipeline.";
// Create two SUN397 Train Dataset.
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds1 = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
std::shared_ptr<Dataset> ds2 = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds.
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 2;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds.
std::vector<std::string> column_project = {"image", "label"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 16);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: SUN397Dataset
/// Description: Test iterator of SUN397Dataset with only the image column
/// Expectation: The dataset is processed successfully
TEST_F(MindDataTestPipeline, TestSUN397IteratorOneColumn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397IteratorOneColumn.";
// Create a SUN397 Dataset
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 2;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// Only select "image" column and drop others
std::vector<std::string> columns = {"image"};
std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::vector<mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expect_image = {2, 256, 256, 3};
uint64_t i = 0;
while (row.size() != 0) {
for (auto &v : row) {
MS_LOG(INFO) << "image shape:" << v.Shape();
EXPECT_EQ(expect_image, v.Shape());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SUN397Dataset
/// Description: Test iterator of SUN397Dataset with wrong column
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestSUN397IteratorWrongColumn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397IteratorWrongColumn.";
// Create a SUN397 Dataset
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, std::make_shared<RandomSampler>(false, 4));
EXPECT_NE(ds, nullptr);
// Pass wrong column name
std::vector<std::string> columns = {"digital"};
std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
/// Feature: SUN397Dataset
/// Description: Test usage of GetDatasetSize of SUN397TrainDataset
/// Expectation: Get the correct size
TEST_F(MindDataTestPipeline, TestGetSUN397TrainDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSUN397TrainDatasetSize.";
// Create a SUN397 Train Dataset.
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 4);
}
/// Feature: SUN397Dataset
/// Description: Test SUN397Dataset Getters method
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestPipeline, TestSUN397TrainDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397TrainDatasetGetters.";
// Create a SUN397 Train Dataset.
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 4);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[0].ToString(), "<256,256,3>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 4);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 4);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 4);
}
/// Feature: SUN397Dataset
/// Description: Test SUN397Dataset with invalid folder path input
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestSUN397DatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397DatasetFail.";
// Create a SUN397 Dataset.
std::shared_ptr<Dataset> ds = SUN397("", true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid SUN397 input.
EXPECT_EQ(iter, nullptr);
}
/// Feature: SUN397Dataset
/// Description: Test SUN397Dataset with null sampler
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestSUN397DatasetWithNullSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSUN397DatasetWithNullSamplerFail.";
// Create a SUN397 Dataset.
std::string folder_path = datasets_root_path_ + "/testSUN397Data";
std::shared_ptr<Dataset> ds = SUN397(folder_path, true, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid SUN397 input, sampler cannot be nullptr.
EXPECT_EQ(iter, nullptr);
}


@ -0,0 +1,2 @@
/a/abbey
/b/badlands

Four binary test images added (not shown). Sizes: 6.1 KiB, 4.9 KiB, 6.1 KiB, 4.3 KiB.


@ -0,0 +1,498 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SUN397 dataset operators
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testSUN397Data"
def test_sun397_basic():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file
Expectation: The data is processed successfully
"""
logger.info("Test Case basic")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_sun397_num_samples():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file with num_samples=10 and num_parallel_workers=2
Expectation: The data is processed successfully
"""
logger.info("Test Case num_samples")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.SUN397Dataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.SUN397Dataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
def test_sun397_num_shards():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file with num_shards=2 and shard_id=1
Expectation: The data is processed successfully
"""
logger.info("Test Case numShards")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=1)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
def test_sun397_shard_id():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file with num_shards=2 and shard_id=0
Expectation: The data is processed successfully
"""
logger.info("Test Case withShardID")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=0)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
def test_sun397_no_shuffle():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file with shuffle=False
Expectation: The data is processed successfully
"""
logger.info("Test Case noShuffle")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_sun397_extra_shuffle():
"""
Feature: Test SUN397 Dataset
Description: Read data from all file with shuffle=True
Expectation: The data is processed successfully
"""
logger.info("Test Case extra_shuffle")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, shuffle=True)
data1 = data1.shuffle(buffer_size=5)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_sun397_decode():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with decode=True
    Expectation: The dataset is as expected
    """
logger.info("Test Case decode")
# define parameters
repeat_count = 1
# apply dataset operations
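    # decode=True returns decoded image arrays instead of the raw encoded bytes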
data1 = ds.SUN397Dataset(DATA_DIR, decode=True)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4


def test_sequential_sampler():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with sampler=ds.SequentialSampler()
    Expectation: The data is processed successfully
    """
logger.info("Test Case SequentialSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.SequentialSampler(num_samples=10)
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
result = []
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
result.append(item["label"])
num_iter += 1
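    # the sampler asks for 10 samples but only 4 images exist, so 4 rows are returned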
assert num_iter == 4
logger.info("Result: {}".format(result))


def test_random_sampler():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with sampler=ds.RandomSampler()
    Expectation: The data is processed successfully
    """
logger.info("Test Case RandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.RandomSampler()
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4


def test_distributed_sampler():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with sampler=ds.DistributedSampler()
    Expectation: The data is processed successfully
    """
logger.info("Test Case DistributedSampler")
# define parameters
repeat_count = 1
# apply dataset operations
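    # 2 shards with shard_id=1: this reader sees half of the 4 images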
sampler = ds.DistributedSampler(2, 1)
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_pk_sampler():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with sampler=ds.PKSampler()
    Expectation: The data is processed successfully
    """
logger.info("Test Case PKSampler")
# define parameters
repeat_count = 1
# apply dataset operations
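    # PKSampler(1) keeps at most 1 sample per class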
sampler = ds.PKSampler(1)
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_chained_sampler():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with chained Random and Sequential samplers, with repeat
    Expectation: The data is processed successfully
    """
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
# Create chained sampler, random and sequential
sampler = ds.RandomSampler()
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
# Create SUN397Dataset with sampler
data1 = ds.SUN397Dataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(count=3)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 12
# Verify number of iterations
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12


def test_sun397_zip():
    """
    Feature: Test SUN397 Dataset
    Description: Read data from all files with zip operation
    Expectation: The data is processed successfully
    """
logger.info("Test Case zip")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.SUN397Dataset(DATA_DIR, num_samples=10)
data2 = ds.SUN397Dataset(DATA_DIR, num_samples=10)
data1 = data1.repeat(repeat_count)
    # rename the columns of dataset2 to avoid a column name conflict
    data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
data3 = ds.zip((data1, data2))
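    # the zipped dataset is only as long as its smaller input: data2 has 4 rows while the repeated data1 has 8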
num_iter = 0
# each data is a dictionary
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
        # after zip, each dictionary has keys "image", "label", "image1" and "label1"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4


def test_sun397_exception():
    """
    Feature: Test SUN397 Dataset
    Description: Pass invalid arguments to SUN397Dataset
    Expectation: Correct error is raised as expected
    """
logger.info("Test sun397 exception")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.SUN397Dataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.SUN397Dataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.SUN397Dataset(DATA_DIR, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.SUN397Dataset(DATA_DIR, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.SUN397Dataset(DATA_DIR, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.SUN397Dataset(DATA_DIR, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id=5)
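    # num_parallel_workers must be a positive value no larger than the available thread count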
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.SUN397Dataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.SUN397Dataset(DATA_DIR, num_shards=2, shard_id="0")


def test_sun397_exception_map():
    """
    Feature: Test SUN397 Dataset
    Description: Apply a failing map operation to SUN397Dataset
    Expectation: Correct error is raised as expected
    """
logger.info("Test sun397 exception map")
    def exception_func(item):
        raise Exception("Error occurred!")

    def exception_func2(image, label):
        raise Exception("Error occurred!")

try:
data = ds.SUN397Dataset(DATA_DIR)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
try:
data = ds.SUN397Dataset(DATA_DIR)
data = data.map(operations=exception_func2,
input_columns=["image", "label"],
output_columns=["image", "label", "label1"],
num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
try:
data = ds.SUN397Dataset(DATA_DIR)
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)


if __name__ == '__main__':
test_sun397_basic()
test_sun397_num_samples()
test_sequential_sampler()
test_random_sampler()
test_distributed_sampler()
    test_pk_sampler()
    test_chained_sampler()
test_sun397_num_shards()
test_sun397_shard_id()
test_sun397_no_shuffle()
test_sun397_extra_shuffle()
test_sun397_decode()
test_sun397_zip()
test_sun397_exception()
test_sun397_exception_map()