[assistant][ops]New operator implementation, include Food101Dataset

This commit is contained in:
uccInf 2022-09-20 23:04:15 +08:00
parent ffe8b2b0cf
commit 54dc75e6a1
29 changed files with 1387 additions and 8 deletions

View File

@ -0,0 +1,102 @@
mindspore.dataset.Food101Dataset
================================
.. py:class:: mindspore.dataset.Food101Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
读取和解析Food101数据集的源文件构建数据集。
生成的数据集有两列: `[image, label]`。`image` 列的数据类型为uint8。`label` 列的数据类型为string。
参数:
- **dataset_dir** (str) - 包含数据集文件的根目录路径。
- **usage** (str, 可选) - 指定数据集的子集,可取值为 'train'、'test' 或 'all'。
取值为'train'时将会读取75,750个训练样本;取值为'test'时将会读取25,250个测试样本;取值为'all'时将会读取全部101,000个样本。默认值None读取全部样本图片。
- **num_samples** (int, 可选) - 指定从数据集中读取的样本数可以小于数据集总数。默认值None读取全部样本图片。
- **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值None使用mindspore.dataset.config中配置的线程数。
- **shuffle** (bool, 可选) - 是否混洗数据集。默认值None下表中会展示不同参数配置的预期行为。
- **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值False不解码。
- **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值None下表中会展示不同配置的预期行为。
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值None。指定此参数后 `num_samples` 表示每个分片的最大样本数。
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值None。只有当指定了 `num_shards` 时才能指定此参数。
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值None不使用缓存。
异常:
- **RuntimeError** - `dataset_dir` 路径下不包含数据文件。
- **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。
- **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。
- **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
- **ValueError** - `shard_id` 参数错误小于0或者大于等于 `num_shards`
- **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。
- **ValueError** - `usage` 参数取值不为'train'、'test'或'all'。
- **ValueError** - `dataset_dir` 指定的文件夹不存在。
.. note:: 此数据集可以指定参数 `sampler` ,但参数 `sampler` 和参数 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数组合及预期的行为。
.. list-table:: 配置 `sampler` 和 `shuffle` 的不同组合得到的预期排序结果
:widths: 25 25 50
:header-rows: 1
* - 参数 `sampler`
- 参数 `shuffle`
- 预期数据顺序
* - None
- None
- 随机排列
* - None
- True
- 随机排列
* - None
- False
- 顺序排列
* - `sampler` 实例
- None
- 由 `sampler` 行为定义的顺序
* - `sampler` 实例
- True
- 不允许
* - `sampler` 实例
- False
- 不允许
**关于Food101数据集**
Food101是一个具有挑战性的数据集包含101种食品类别共101000张图片。每一个类别有250张测试图片和750张训练图片。所有图像都被重新缩放最大边长为512像素。
以下为原始Food101数据集的结构您可以将数据集文件解压得到如下的文件结构并通过MindSpore的API进行读取。
.. code-block::
.
└── food101_dir
├── images
│ ├── apple_pie
│ │ ├── 1005649.jpg
│ │ ├── 1014775.jpg
│ │ ├──...
│ ├── baby_back_ribs
│ │ ├── 1005293.jpg
│ │ ├── 1007102.jpg
│ │ ├──...
│ └──...
└── meta
├── train.txt
├── test.txt
├── classes.txt
├── train.json
├── test.json
└── labels.txt
**引用:**
.. code-block::
@inproceedings{bossard14,
title = {Food-101 -- Mining Discriminative Components with Random Forests},
author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
booktitle = {European Conference on Computer Vision},
year = {2014}
}
.. include:: mindspore.dataset.api_list_vision.rst

View File

@ -112,6 +112,7 @@ mindspore.dataset
mindspore.dataset.FashionMnistDataset
mindspore.dataset.FlickrDataset
mindspore.dataset.Flowers102Dataset
mindspore.dataset.Food101Dataset
mindspore.dataset.ImageFolderDataset
mindspore.dataset.KMnistDataset
mindspore.dataset.ManifestDataset

View File

@ -24,6 +24,7 @@ Vision
mindspore.dataset.FashionMnistDataset
mindspore.dataset.FlickrDataset
mindspore.dataset.Flowers102Dataset
mindspore.dataset.Food101Dataset
mindspore.dataset.ImageFolderDataset
mindspore.dataset.KMnistDataset
mindspore.dataset.ManifestDataset

View File

@ -92,6 +92,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
@ -1303,6 +1304,28 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
// Constructor of Food101Dataset (shared_ptr sampler overload).
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  // Convert the client sampler into its IR form; a null sampler stays null and is resolved downstream.
  std::shared_ptr<SamplerObj> sampler_obj = nullptr;
  if (sampler != nullptr) {
    sampler_obj = sampler->Parse();
  }
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache));
}
// Constructor of Food101Dataset (raw Sampler pointer overload).
// A null sampler is forwarded as a null SamplerObj; the IR layer resolves the default sampler later.
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                               const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler ? sampler->Parse() : nullptr;
  auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
// Constructor of Food101Dataset (std::reference_wrapper<Sampler> overload).
// The reference is always valid, so Parse() is invoked unconditionally (no null check needed).
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                               const std::reference_wrapper<Sampler> &sampler,
                               const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler.get().Parse();
  auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;

View File

@ -43,6 +43,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
@ -326,6 +327,18 @@ PYBIND_REGISTER(FlickrNode, 2, ([](const py::module *m) {
}));
}));
// Python binding for Food101Node: exposes a constructor taking (dataset_dir, usage, decode, sampler).
// ValidateParams() is called eagerly so invalid arguments raise at dataset-creation time in Python.
PYBIND_REGISTER(Food101Node, 2, ([](const py::module *m) {
                  (void)py::class_<Food101Node, DatasetNode, std::shared_ptr<Food101Node>>(*m, "Food101Node",
                                                                                          "to create a Food101Node")
                    .def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
                                     const py::handle &sampler) {
                      auto food101 =
                        std::make_shared<Food101Node>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
                      THROW_IF_ERROR(food101->ValidateParams());
                      return food101;
                    }));
                }));
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
*m, "GeneratorNode", "to create a GeneratorNode")

View File

@ -22,6 +22,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
fake_image_op.cc
fashion_mnist_op.cc
flickr_op.cc
food101_op.cc
gtzan_op.cc
image_folder_op.cc
imdb_op.cc

View File

@ -0,0 +1,164 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/food101_op.h"
#include <algorithm>
#include <iomanip>
#include <regex>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/path.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
// Constructor of Food101Op.
// The member-initializer list follows the declaration order in food101_op.h
// (folder_path_, usage_, decode_, data_schema_); the original order (decode_ before
// usage_) did not match the declaration order and triggered -Wreorder warnings.
Food101Op::Food101Op(const std::string &folder_path, const std::string &usage, int32_t num_workers, int32_t queue_size,
                     bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
    : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
      folder_path_(folder_path),
      usage_(usage),
      decode_(decode),
      data_schema_(std::move(data_schema)) {}
// Scan the dataset directory: collect class names from the sub-folders of "images" and
// fill all_img_lists_ from the meta list file(s) selected by usage_.
Status Food101Op::PrepareData() {
  // Resolve the dataset root and make sure it exists.
  auto realpath = FileUtils::GetRealPath(folder_path_.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, Food101 dataset dir: " << folder_path_ << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, Food101 dataset dir: " + folder_path_ + " does not exist.");
  }
  // Locals renamed without the trailing underscore, which is reserved for data members.
  std::string image_root_path = (Path(realpath.value()) / Path("images")).ToString();
  std::string meta_path = (Path(realpath.value()) / Path("meta")).ToString();
  std::string train_list_txt = (Path(realpath.value()) / Path("meta") / Path("train.txt")).ToString();
  std::string test_list_txt = (Path(realpath.value()) / Path("meta") / Path("test.txt")).ToString();
  Path img_folder(image_root_path);
  CHECK_FAIL_RETURN_UNEXPECTED(img_folder.Exists(),
                               "Invalid path, Food101 image path: " + image_root_path + " does not exist.");
  CHECK_FAIL_RETURN_UNEXPECTED(img_folder.IsDirectory(),
                               "Invalid path, Food101 image path: " + image_root_path + " is not a folder.");
  std::shared_ptr<Path::DirIterator> img_folder_itr = Path::DirIterator::OpenDirectory(&img_folder);
  RETURN_UNEXPECTED_IF_NULL(img_folder_itr);
  // Each sub-directory of "images" is one class name; strip the "<images dir>/" prefix.
  int32_t dirname_offset = img_folder.ToString().length() + 1;
  while (img_folder_itr->HasNext()) {
    Path sub_dir = img_folder_itr->Next();
    if (sub_dir.IsDirectory()) {
      classes_.insert(sub_dir.ToString().substr(dirname_offset));
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!classes_.empty(),
                               "Invalid path, no subfolder found under path: " + img_folder.ToString());
  // "train" / "test" read the matching list file; any other usage ("all") reads both.
  if (usage_ == "test") {
    RETURN_IF_NOT_OK(GetAllImageList(test_list_txt));
  } else if (usage_ == "train") {
    RETURN_IF_NOT_OK(GetAllImageList(train_list_txt));
  } else {
    RETURN_IF_NOT_OK(GetAllImageList(train_list_txt));
    RETURN_IF_NOT_OK(GetAllImageList(test_list_txt));
  }
  // The list files live under "meta", so point the error message there (the original
  // message reported the "images" path, which was misleading).
  CHECK_FAIL_RETURN_UNEXPECTED(!all_img_lists_.empty(),
                               "No valid train.txt or test.txt file under path: " + meta_path);
  all_img_lists_.shrink_to_fit();
  num_rows_ = all_img_lists_.size();
  return Status::OK();
}
// A print method typically used for debugging; show_all selects brief vs detailed output.
void Food101Op::Print(std::ostream &out, bool show_all) const {
  // The base-class summary is emitted in both modes, so hoist it out of the branch.
  ParallelOp::Print(out, show_all);
  if (!show_all) {
    out << "\n";
  } else {
    out << "\nNumber of rows: " << num_rows_ << "\nFood101 dataset dir: " << folder_path_
        << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
  }
}
// Load one row: read the image file for all_img_lists_[row_id] and derive its class label.
// \param[in] row_id Row id.
// \param[out] trow Output tensor row {image, label}.
Status Food101Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);
  const std::string &relative = all_img_lists_[row_id];
  std::string img_name = (Path(folder_path_) / Path("images") / Path(relative)).ToString() + ".jpg";
  std::shared_ptr<Tensor> image, label;
  std::string label_str;
  // Entries in the meta list files have the form "<class_name>/<image_id>", so the label is
  // the path component before the first '/'. This replaces the original substring scan over
  // every class, which was O(#classes) per row and could match the wrong class if one class
  // name were contained in another. Fall back to the scan for any malformed entry.
  auto slash_pos = relative.find('/');
  if (slash_pos != std::string::npos && classes_.find(relative.substr(0, slash_pos)) != classes_.end()) {
    label_str = relative.substr(0, slash_pos);
  } else {
    for (const auto &cls : classes_) {
      if (relative.find(cls) != std::string::npos) {
        label_str = cls;
        break;
      }
    }
  }
  RETURN_IF_NOT_OK(Tensor::CreateScalar(label_str, &label));
  RETURN_IF_NOT_OK(ReadImageToTensor(img_name, &image));
  (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
  trow->setPath({img_name, std::string("")});
  return Status::OK();
}
// Read one image file into a tensor, decoding it if decode_ was requested.
// \param[in] image_path Path of image data.
// \param[out] tensor Output image tensor.
Status Food101Op::ReadImageToTensor(const std::string &image_path, std::shared_ptr<Tensor> *tensor) {
  RETURN_UNEXPECTED_IF_NULL(tensor);
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(image_path, tensor));
  if (!decode_) {
    return Status::OK();
  }
  // Decode in place and surface a readable error on failure.
  Status rc = Decode(*tensor, tensor);
  CHECK_FAIL_RETURN_UNEXPECTED(
    rc.IsOk(), "Invalid file, failed to decode image, the image may be broken or permission denied: " + image_path);
  return Status::OK();
}
// Get dataset size.
// Get dataset size. Lazily scans the dataset the first time the size is requested.
Status Food101Op::CountTotalRows(int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  if (all_img_lists_.empty()) {
    RETURN_IF_NOT_OK(PrepareData());
  }
  *count = static_cast<int64_t>(all_img_lists_.size());
  return Status::OK();
}
// Append every non-empty line of the given meta list file (train.txt / test.txt) to
// all_img_lists_. Lines are expected to be relative image paths without the ".jpg"
// suffix (see LoadTensorRow, which appends it) — TODO confirm against the dataset layout.
Status Food101Op::GetAllImageList(const std::string &file_path) {
  std::ifstream handle(file_path);
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open text:" + file_path +
                             ", the file is damaged or permission denied.");
  }
  std::string line;
  while (getline(handle, line)) {
    if (!line.empty()) {
      all_img_lists_.push_back(line);
    }
  }
  return Status::OK();
}
// Set the column name map (base class field) from the schema, exactly once.
Status Food101Op::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
    column_name_id_map_[data_schema_->Column(index).Name()] = index;
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,103 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_
#include <fstream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Leaf dataset op that reads the Food101 dataset from disk.
class Food101Op : public MappableLeafOp {
 public:
  /// \brief Constructor.
  /// \param[in] folder_path Directory of Food101 dataset.
  /// \param[in] usage Usage of this dataset, expected to be "train", "test" or "all".
  /// \param[in] num_workers Number of workers reading images in parallel.
  /// \param[in] queue_size Connector queue size.
  /// \param[in] decode Whether to decode images.
  /// \param[in] schema Data schema of Food101 dataset.
  /// \param[in] sampler Sampler tells Food101 what to read.
  Food101Op(const std::string &folder_path, const std::string &usage, int32_t num_workers, int32_t queue_size,
            bool decode, std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler);

  /// \brief Destructor.
  ~Food101Op() override = default;

  /// \brief A print method typically used for debugging.
  /// \param[out] out Out stream.
  /// \param[in] show_all Whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  /// \brief Op name getter.
  /// \return Name of the current Op.
  std::string Name() const override { return "Food101Op"; }

  /// \brief Function to count the number of samples in the Food101 dataset.
  /// \param[out] count Output arg that will hold the actual dataset size.
  /// \return The status code returned.
  Status CountTotalRows(int64_t *count);

 private:
  /// \brief Load a tensor row.
  /// \param[in] row_id Row id.
  /// \param[out] row Read all features into this tensor row.
  /// \return The status code returned.
  Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;

  /// \brief Read (and optionally decode) one image file into a tensor.
  /// \param[in] image_path Path of image data.
  /// \param[out] tensor Output image tensor.
  /// \return The status code returned.
  Status ReadImageToTensor(const std::string &image_path, std::shared_ptr<Tensor> *tensor);

  /// \brief Scan the dataset folder and the meta ".txt" list files to fill classes_ and all_img_lists_.
  /// \return The status code returned.
  Status PrepareData() override;

  /// \brief Private function for computing the assignment of the column name map.
  /// \return The status code returned.
  Status ComputeColMap() override;

  /// \brief Private function for collecting all image entries of one meta list file.
  /// \param[in] file_path The path of the list file.
  /// \return The status code returned.
  Status GetAllImageList(const std::string &file_path);

  std::string folder_path_;  // directory of Food101 folder
  std::string usage_;        // "train", "test" or "all"
  bool decode_;              // whether to decode images when loading
  std::unique_ptr<DataSchema> data_schema_;
  std::set<std::string> classes_;           // class names, i.e. the sub-folder names under "images"
  std::vector<std::string> all_img_lists_;  // relative image paths read from the meta list files
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_

View File

@ -95,6 +95,7 @@ constexpr char kEnWik9Node[] = "EnWik9Dataset";
constexpr char kFakeImageNode[] = "FakeImageDataset";
constexpr char kFashionMnistNode[] = "FashionMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kFood101Node[] = "Food101Dataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kGTZANNode[] = "GTZANDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";

View File

@ -23,6 +23,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
fake_image_node.cc
fashion_mnist_node.cc
flickr_node.cc
food101_node.cc
gtzan_node.cc
image_folder_node.cc
imdb_node.cc

View File

@ -0,0 +1,150 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/food101_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor of Food101Node, the IR node backing Food101Dataset.
Food101Node::Food101Node(const std::string &dataset_dir, const std::string &usage, bool decode,
                         const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
    : MappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      usage_(usage),
      decode_(decode),
      sampler_(sampler) {}
std::shared_ptr<DatasetNode> Food101Node::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<Food101Node>(dataset_dir_, usage_, decode_, sampler, cache_);
(void)node->SetNumWorkers(num_workers_);
(void)node->SetConnectorQueueSize(connector_que_size_);
return node;
}
// Print only the node name as a brief description of this IR node.
void Food101Node::Print(std::ostream &out) const { out << Name(); }
// Validate user-supplied parameters: dataset_dir must be an accessible directory, the
// sampler must be valid, and usage must be one of "train", "test" or "all".
Status Food101Node::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Food101Dataset", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("Food101Dataset", sampler_));
  RETURN_IF_NOT_OK(ValidateStringValue("Food101Dataset", usage_, {"train", "test", "all"}));
  return Status::OK();
}
// Build the runtime op for this IR node.
Status Food101Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // Do internal Schema generation.
  RETURN_UNEXPECTED_IF_NULL(node_ops);
  auto schema = std::make_unique<DataSchema>();
  // Column 0: "image" holds raw uint8 bytes (decoded later inside the op if decode_ is set).
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  // Column 1: "label" is a scalar string with the class name.
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
  // Build the runtime sampler, create the leaf op and propagate repeat settings.
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  auto op = std::make_shared<Food101Op>(dataset_dir_, usage_, num_workers_, connector_que_size_, decode_,
                                        std::move(schema), std::move(sampler_rt));
  op->SetTotalRepeats(GetTotalRepeats());
  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node
Status Food101Node::GetShardId(int32_t *const shard_id) {
  // The sampler owns the sharding configuration, so delegate to it.
  *shard_id = sampler_->ShardId();
  return Status::OK();
}
// Get Dataset size
Status Food101Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                   int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  // Return the cached size if it has already been computed.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = 0, sample_size;
  // Build a throwaway op to count the rows on disk.
  std::vector<std::shared_ptr<DatasetOp>> ops;
  RETURN_IF_NOT_OK(Build(&ops));
  CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build Food101Op.");
  auto op = std::dynamic_pointer_cast<Food101Op>(ops.front());
  RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
  // Let the sampler reduce the raw row count to the number of rows it will actually emit.
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  sample_size = sampler_rt->CalculateNumSamples(num_rows);
  // -1 means the sampler cannot compute the size statically; fall back to a dry run.
  if (sample_size == -1) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
  }
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;  // cache for subsequent calls
  return Status::OK();
}
// Serialize this node (including its sampler and optional cache) into a JSON object.
Status Food101Node::to_json(nlohmann::json *out_json) {
  nlohmann::json args, sampler_args;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
  args["sampler"] = sampler_args;
  args["num_parallel_workers"] = num_workers_;
  args["connector_queue_size"] = connector_que_size_;
  args["dataset_dir"] = dataset_dir_;
  args["decode"] = decode_;
  args["usage"] = usage_;
  // The cache entry is only written when a cache is attached.
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a Food101Node from JSON (inverse of to_json).
Status Food101Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Check every required key up front so a malformed JSON fails with a clear message.
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFood101Node));
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kFood101Node));
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kFood101Node));
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kFood101Node));
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kFood101Node));
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kFood101Node));
  std::string dataset_dir = json_obj["dataset_dir"];
  bool decode = json_obj["decode"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  // Cache is optional; from_json leaves it null when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<Food101Node>(dataset_dir, usage, decode, sampler, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  (*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,104 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node representing the Food101 source dataset in the dataset tree.
class Food101Node : public MappableSourceNode {
 public:
  /// \brief Constructor.
  Food101Node(const std::string &dataset_dir, const std::string &usage, bool decode,
              const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
  /// \brief Destructor.
  ~Food101Node() override = default;
  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kFood101Node; }
  /// \brief Print the description.
  /// \param[out] out The output stream to write output to.
  void Print(std::ostream &out) const override;
  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;
  /// \brief a base class override function to create the required runtime dataset op objects for this class.
  /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build successfully.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;
  /// \brief Get the shard id of node.
  /// \param[in] shard_id The shard id of node.
  /// \return Status Status::OK() if get shard id successfully.
  Status GetShardId(int32_t *const shard_id) override;
  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size The size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;
  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }
  const std::string &Usage() const { return usage_; }
  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
  /// \brief Function to read dataset in json.
  /// \param[in] json_obj The JSON object to be deserialized.
  /// \param[out] ds Deserialized dataset.
  /// \return Status The status code returned.
  static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
  /// \brief Sampler getter.
  /// \return SamplerObj of the current node.
  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
  /// \brief Sampler setter.
  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

 private:
  std::string dataset_dir_;
  std::string usage_;
  bool decode_;
  std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_

View File

@ -2763,6 +2763,99 @@ inline std::shared_ptr<FlickrDataset> DATASET_API Flickr(const std::string &data
cache);
}
/// \class Food101Dataset
/// \brief A source dataset for reading and parsing Food101 dataset.
/// \note The generated dataset has two columns ["image", "label"].
class DATASET_API Food101Dataset : public Dataset {
 public:
  /// \brief Constructor of Food101Dataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor of Food101Dataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor of Food101Dataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Destructor of Food101Dataset.
  ~Food101Dataset() override = default;
};
/// \brief Function to create a Food101Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all". Default: "all".
/// \param[in] decode Decode the images after reading. Default: false.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset. Default: RandomSampler().
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the Food101Dataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string dataset_path = "/path/to/Food101_dataset_directory";
/// std::shared_ptr<Dataset> ds = Food101(dataset_path);
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In Food101 dataset, each data dictionary has keys "image" and "label" */
/// auto image = row["image"];
/// \endcode
inline std::shared_ptr<Food101Dataset> DATASET_API
Food101(const std::string &dataset_dir, const std::string &usage = "all", bool decode = false,
        const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
        const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Convert std::string arguments to the char-vector ABI type and forward to the constructor.
  auto ds = std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
  return ds;
}
/// \brief Function to create a Food101Dataset
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all" .
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the Food101Dataset.
inline std::shared_ptr<Food101Dataset> DATASET_API Food101(const std::string &dataset_dir, const std::string &usage,
                                                           bool decode, const Sampler *sampler,
                                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Forward to the Food101Dataset constructor taking a raw Sampler pointer.
  return std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
}
/// \brief Function to create a Food101Dataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset, passed by std::reference_wrapper.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the Food101Dataset.
inline std::shared_ptr<Food101Dataset> DATASET_API Food101(const std::string &dataset_dir, const std::string &usage,
                                                           bool decode, const std::reference_wrapper<Sampler> &sampler,
                                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
}
/// \class GTZANDataset
/// \brief A source dataset for reading and parsing GTZAN dataset.
class DATASET_API GTZANDataset : public Dataset {

View File

@ -47,6 +47,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
friend class FakeImageDataset;
friend class FashionMnistDataset;
friend class FlickrDataset;
friend class Food101Dataset;
friend class GTZANDataset;
friend class ImageFolderDataset;
friend class IMDBDataset;

View File

@ -50,6 +50,7 @@ __all__ = ["Caltech101Dataset", # Vision
"FashionMnistDataset", # Vision
"FlickrDataset", # Vision
"Flowers102Dataset", # Vision
"Food101Dataset", # Vision
"ImageFolderDataset", # Vision
"KITTIDataset", # Vision
"KMnistDataset", # Vision

View File

@ -32,14 +32,13 @@ import mindspore._c_dataengine as cde
from .datasets import VisionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema
from .datasets_user_defined import GeneratorDataset
from .validators import check_imagefolderdataset, check_kittidataset,\
check_mnist_cifar_dataset, check_manifestdataset, check_vocdataset, check_cocodataset, \
check_celebadataset, check_flickr_dataset, check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, \
check_usps_dataset, check_div2k_dataset, check_random_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_svhn_dataset, check_stl10_dataset, check_semeion_dataset, \
check_caltech101_dataset, check_caltech256_dataset, check_wider_face_dataset, check_lfw_dataset, \
check_lsun_dataset, check_omniglotdataset
from .validators import check_caltech101_dataset, check_caltech256_dataset, check_celebadataset, \
check_cityscapes_dataset, check_cocodataset, check_div2k_dataset, check_emnist_dataset, check_fake_image_dataset, \
check_flickr_dataset, check_flowers102dataset, check_food101_dataset, check_imagefolderdataset, \
check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset
from ..core.validator_helpers import replace_none
@ -2190,6 +2189,140 @@ class Flowers102Dataset(GeneratorDataset):
return class_dict
class Food101Dataset(MappableDataset, VisionBaseDataset):
    """
    A source dataset that reads and parses Food101 dataset.

    The generated dataset has two columns :py:obj:`[image, label]` .
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is of the string type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be 'train', 'test', or 'all'. 'train' will read
            from 75,750 samples, 'test' will read from 25,250 samples, and 'all' will read all 'train'
            and 'test' samples. Default: None, will be set to 'all'.
        num_samples (int, optional): The number of images to be included in the dataset.
            Default: None, will read all images.
        num_parallel_workers (int, optional): Number of workers to read the data.
            Default: None, number set in the mindspore.dataset.config.
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
            Default: None, expected order behavior shown in the table below.
        decode (bool, optional): Decode the images after reading. Default: False.
        sampler (Sampler, optional): Object used to choose samples from the dataset.
            Default: None, expected order behavior shown in the table below.
        num_shards (int, optional): Number of shards that the dataset will be divided into. When this argument
            is specified, `num_samples` reflects the maximum sample number of per shard. Default: None.
        shard_id (int, optional): The shard ID within `num_shards` . This argument can only be specified
            when `num_shards` is also specified. Default: None.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: None, which means no cache is used.

    Raises:
        RuntimeError: If `dataset_dir` does not contain data files.
        RuntimeError: If `sampler` and `shuffle` are specified at the same time.
        RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
        ValueError: If the value of `usage` is not 'train', 'test', or 'all'.
        ValueError: If `dataset_dir` is not exist.

    Note:
        - This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> food101_dataset_dir = "/path/to/food101_dataset_directory"
        >>>
        >>> # Read 3 samples from Food101 dataset
        >>> dataset = ds.Food101Dataset(dataset_dir=food101_dataset_dir, num_samples=3)

    About Food101 dataset:

    The Food101 is a challenging dataset of 101 food categories, with 101,000 images.
    There are 250 test images and 750 training images in each class. All images were rescaled
    to have a maximum side length of 512 pixels.

    The following is the original Food101 dataset structure.
    You can unzip the dataset files into this directory structure and read by MindSpore's API.

    .. code-block::

        .
        food101_dir
            images
                apple_pie
                    1005649.jpg
                    1014775.jpg
                    ...
                baby_back_ribs
                    1005293.jpg
                    1007102.jpg
                    ...
                ...
            meta
                train.txt
                test.txt
                classes.txt
                train.json
                test.json
                labels.txt

    Citation:

    .. code-block::

        @inproceedings{bossard14,
          title = {Food-101 -- Mining Discriminative Components with Random Forests},
          author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
          booktitle = {European Conference on Computer Vision},
          year = {2014}
        }
    """

    @check_food101_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
                 decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir
        # usage defaults to "all" (both train and test splits) when not given.
        self.usage = replace_none(usage, "all")
        self.decode = replace_none(decode, False)

    def parse(self, children=None):
        # Build the C++ Food101 dataset node consumed by the execution tree.
        return cde.Food101Node(self.dataset_dir, self.usage, self.decode, self.sampler)
class ImageFolderDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads images from a tree of directories.

View File

@ -2328,6 +2328,36 @@ def check_flickr_dataset(method):
return new_method
def check_food101_dataset(method):
    """A wrapper that wraps a parameter checker around the Food101Dataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        # Parameters that must be non-negative integers / booleans when provided.
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['decode', 'shuffle']

        # dataset_dir must exist before any other validation is meaningful.
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # usage is optional; when given it must be one of the three splits.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        # Enforces mutual exclusion of sampler with shuffle/num_shards/shard_id.
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
def check_sb_dataset(method):
"""A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""

View File

@ -0,0 +1,245 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
// Test fixture for Food101 dataset pipeline tests; inherits the common
// dataset test utilities (e.g. datasets_root_path_) from UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Feature: Food101Dataset
/// Description: Test basic usage of Food101Dataset
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestFood101TestDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDataset.";

  // Create a Food101 Test Dataset with a RandomSampler drawing 4 samples without replacement
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";
  std::shared_ptr<Dataset> ds = Food101(folder_path, "test", true, std::make_shared<RandomSampler>(false, 4));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {  // an empty row marks the end of the epoch
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 4 samples, so exactly 4 rows must come back.
  EXPECT_EQ(i, 4);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset in pipeline mode
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestFood101TestDatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDatasetWithPipeline.";
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";

  // Create two Food101 Test Datasets (4 samples each)
  std::shared_ptr<Dataset> ds1 = Food101(folder_path, "test", true, std::make_shared<RandomSampler>(false, 4));
  std::shared_ptr<Dataset> ds2 = Food101(folder_path, "test", true, std::make_shared<RandomSampler>(false, 4));
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create a Repeat operation on each dataset
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operations on ds
  std::vector<std::string> column_project = {"image", "label"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Create a Concat operation on the ds
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {  // an empty row marks the end of the epoch
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // 4 samples per branch, concatenated => 8 rows in total.
  EXPECT_EQ(i, 8);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset GetDatasetSize
/// Expectation: Correct size of dataset
TEST_F(MindDataTestPipeline, TestGetFood101TestDatasetSize) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetFood101TestDatasetSize.";
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";

  // Create a Food101 Dataset over both splits ("all")
  std::shared_ptr<Dataset> ds = Food101(folder_path, "all");
  EXPECT_NE(ds, nullptr);

  // The test data directory contains 8 samples in total.
  EXPECT_EQ(ds->GetDatasetSize(), 8);
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset Getters method
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestPipeline, TestFood101TestDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDatasetGetters.";

  // Create a Food101 Test Dataset
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";
  std::shared_ptr<Dataset> ds = Food101(folder_path, "test");
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 4);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  int64_t num_classes = ds->GetNumClasses();

  // "image" is uint8; "label" is a scalar string (shape "<>").
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "string");
  EXPECT_EQ(shapes.size(), 2);
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(num_classes, -1);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // The getters are queried several times on purpose, so that repeated
  // calls (which may hit cached values) return the same results.
  EXPECT_EQ(ds->GetDatasetSize(), 4);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), -1);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 4);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), -1);
  EXPECT_EQ(ds->GetDatasetSize(), 4);
}
/// Feature: Food101Dataset
/// Description: Test iterator of Food101Dataset with wrong column
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestFood101IteratorWrongColumn) {
  // Log message fixed to match the actual test name (was "TestFood101IteratorOneColumn").
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101IteratorWrongColumn.";

  // Create a Food101 Dataset
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";
  std::shared_ptr<Dataset> ds = Food101(folder_path, "all", false, std::make_shared<RandomSampler>(false, 4));
  EXPECT_NE(ds, nullptr);

  // Pass wrong column name; creating an iterator over it must fail.
  std::vector<std::string> columns = {"digital"};
  std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
  std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
  EXPECT_EQ(iter, nullptr);
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset with empty string as the folder path
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestFood101DatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101DatasetFail.";

  // Create a Food101 Dataset with an empty (invalid) folder path.
  // Construction itself still returns a non-null object; validation
  // happens when the iterator is created.
  std::shared_ptr<Dataset> ds = Food101("", "train", false, std::make_shared<RandomSampler>(false, 2));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid Food101 input
  EXPECT_EQ(iter, nullptr);
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset with invalid usage
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestFood101DatasetWithInvalidUsageFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101DatasetWithInvalidUsageFail.";

  // Create a Food101 Dataset with a usage value outside {"train", "test", "all"}.
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";
  std::shared_ptr<Dataset> ds = Food101(folder_path, "validation");
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid Food101 input, validation is not a valid usage
  EXPECT_EQ(iter, nullptr);
}
/// Feature: Food101Dataset
/// Description: Test Food101Dataset with null sampler
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
TEST_F(MindDataTestPipeline, TestFood101DatasetWithNullSamplerFail) {
  // Log message fixed: stray 'U' removed ("Food101UDataset" -> "Food101Dataset").
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101DatasetWithNullSamplerFail.";

  // Create a Food101 Dataset with a null sampler
  std::string folder_path = datasets_root_path_ + "/testFood101Data/";
  std::shared_ptr<Dataset> ds = Food101(folder_path, "all", false, nullptr);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid Food101 input, sampler cannot be nullptr
  EXPECT_EQ(iter, nullptr);
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

View File

@ -0,0 +1,4 @@
class1/0
class2/1
class3/0
class4/0

View File

@ -0,0 +1,4 @@
class1/1
class2/0
class3/1
class4/1

View File

@ -0,0 +1,204 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Food101 dataset operators
"""
import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.log as logger
DATA_DIR = "../data/dataset/testFood101Data/"
def test_food101_basic():
    """
    Feature: Food101 dataset
    Description: Read all files
    Expectation: Throw number of data in all files
    """
    logger.info("Test Food101Dataset Op")

    def count_rows(dataset):
        # Exhaust a one-epoch iterator and count the rows it yields.
        return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    # case 1: the default usage reads every sample in the test data
    assert count_rows(ds.Food101Dataset(DATA_DIR)) == 8

    # case 2: num_samples caps the number of rows returned
    assert count_rows(ds.Food101Dataset(DATA_DIR, num_samples=1)) == 1

    # case 3: repeat multiplies the row count
    repeated = ds.Food101Dataset(DATA_DIR, num_samples=2).repeat(5)
    assert count_rows(repeated) == 10
def test_food101_noshuffle():
    """
    Feature: Food101 dataset
    Description: Test noshuffle
    Expectation: Throw number of data in all files
    """
    logger.info("Test Case noShuffle")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    # Note: the default usage "all" reads both the "train" split (4 samples)
    # and the "test" split (4 samples) of the test data.
    data1 = ds.Food101Dataset(DATA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    # each data is a dictionary
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 8
def test_food101_usage():
    """
    Feature: Food101 dataset
    Description: Test Usage
    Expectation: Throw number of data in all files
    """
    logger.info("Test Food101Dataset usage flag")

    def test_config(usage, food101_path=DATA_DIR):
        # Return the row count for the given usage, or the error message
        # string when construction/iteration fails.
        try:
            data = ds.Food101Dataset(food101_path, usage=usage)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    # test the usage of Food101
    assert test_config("test") == 4
    assert test_config("train") == 4
    assert test_config("all") == 8
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config(
        "invalid")

    # change to the folder that contains all Food101 files
    # NOTE: all_food101 is a manual-run placeholder. Point it at a full
    # Food101 download to check the real split sizes; otherwise the block
    # below is intentionally skipped.
    all_food101 = None
    if all_food101 is not None:
        assert test_config("test", all_food101) == 25250
        assert test_config("train", all_food101) == 75750
        assert test_config("all", all_food101) == 101000

        assert ds.Food101Dataset(all_food101, usage="test").get_dataset_size() == 25250
        assert ds.Food101Dataset(all_food101, usage="train").get_dataset_size() == 75750
        assert ds.Food101Dataset(all_food101, usage="all").get_dataset_size() == 101000
def test_food101_sequential_sampler():
    """
    Feature: Food101 dataset
    Description: Test SequentialSampler
    Expectation: Get correct number of data
    """
    num_samples = 1
    sampler = ds.SequentialSampler(num_samples=num_samples)
    # A SequentialSampler and shuffle=False with the same num_samples
    # should yield identical rows in identical order.
    data1 = ds.Food101Dataset(DATA_DIR, 'test', sampler=sampler)
    data2 = ds.Food101Dataset(DATA_DIR, 'test', shuffle=False, num_samples=num_samples)
    matches_list1, matches_list2 = [], []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
        matches_list1.append(item1["image"].asnumpy())
        matches_list2.append(item2["image"].asnumpy())
        num_iter += 1
    np.testing.assert_array_equal(matches_list1, matches_list2)
    assert num_iter == num_samples
def test_food101_pipeline():
    """
    Feature: Pipeline test
    Description: Read a sample
    Expectation: The amount of each function are equal
    """
    # decode=True so the mapped Resize op receives decoded image data.
    dataset = ds.Food101Dataset(DATA_DIR, "test", num_samples=1, decode=True)
    resize_op = vision.Resize((100, 100))
    dataset = dataset.map(input_columns=["image"], operations=resize_op)
    num_iter = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 1
def test_food101_exception():
    """
    Feature: Food101 dataset
    Description: Throw error messages when certain errors occur
    Expectation: Error message
    """
    logger.info("Test error cases for Food101Dataset")

    # sampler and shuffle are mutually exclusive
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Food101Dataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    # sampler and sharding are mutually exclusive
    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Food101Dataset(DATA_DIR, sampler=ds.PKSampler(
            3), num_shards=2, shard_id=0)

    # num_shards and shard_id must be given together
    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Food101Dataset(DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Food101Dataset(DATA_DIR, shard_id=0)

    # shard_id must lie in [0, num_shards)
    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Food101Dataset(DATA_DIR, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Food101Dataset(DATA_DIR, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Food101Dataset(DATA_DIR, num_shards=2, shard_id=5)

    # num_parallel_workers must be positive and within the system limit
    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)

    # shard_id must be an int, not a string
    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.Food101Dataset(DATA_DIR, num_shards=2, shard_id="0")
if __name__ == '__main__':
    # Run every Food101 test case in the original order.
    for test_case in (test_food101_basic,
                      test_food101_sequential_sampler,
                      test_food101_noshuffle,
                      test_food101_usage,
                      test_food101_pipeline,
                      test_food101_exception):
        test_case()