[assistant][ops]New operator implementation, include Food101Dataset
|
@ -0,0 +1,102 @@
|
|||
mindspore.dataset.Food101Dataset
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.dataset.Food101Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None)
|
||||
|
||||
读取和解析Food101数据集的源文件构建数据集。
|
||||
|
||||
生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为string。
|
||||
|
||||
参数:
|
||||
- **dataset_dir** (str) - 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选) - 指定数据集的子集,可取值为 'train'、'test' 或 'all'。
|
||||
取值为'train'时将会读取75,750个训练样本,取值为'test'时将会读取25,250个测试样本,取值为'all'时将会读取全部101,000个样本。默认值:None,读取全部样本图片。
|
||||
- **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值:None,读取全部样本图片。
|
||||
- **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值:None,使用mindspore.dataset.config中配置的线程数。
|
||||
- **shuffle** (bool, 可选) - 是否混洗数据集。默认值:None,下表中会展示不同参数配置的预期行为。
|
||||
- **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值:False,不解码。
|
||||
- **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值:None,下表中会展示不同配置的预期行为。
|
||||
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值:None。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值:None。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值:None,不使用缓存。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。
|
||||
- **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。
|
||||
- **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
|
||||
- **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
|
||||
- **ValueError** - `shard_id` 参数错误,小于0或者大于等于 `num_shards` 。
|
||||
- **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。
|
||||
- **ValueError** - `usage` 参数取值不为'train'、'test'或'all'。
|
||||
- **ValueError** - `dataset_dir` 指定的文件夹不存在。
|
||||
|
||||
.. note:: 此数据集可以指定参数 `sampler` ,但参数 `sampler` 和参数 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数组合及预期的行为。
|
||||
|
||||
.. list-table:: 配置 `sampler` 和 `shuffle` 的不同组合得到的预期排序结果
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 参数 `sampler`
|
||||
- 参数 `shuffle`
|
||||
- 预期数据顺序
|
||||
* - None
|
||||
- None
|
||||
- 随机排列
|
||||
* - None
|
||||
- True
|
||||
- 随机排列
|
||||
* - None
|
||||
- False
|
||||
- 顺序排列
|
||||
* - `sampler` 实例
|
||||
- None
|
||||
- 由 `sampler` 行为定义的顺序
|
||||
* - `sampler` 实例
|
||||
- True
|
||||
- 不允许
|
||||
* - `sampler` 实例
|
||||
- False
|
||||
- 不允许
|
||||
|
||||
**关于Food101数据集:**
|
||||
|
||||
Food101是一个具有挑战性的数据集,包含101种食品类别,共101000张图片。每一个类别有250张测试图片和750张训练图片。所有图像都被重新缩放,最大边长为512像素。
|
||||
|
||||
以下为原始Food101数据集的结构,您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── food101_dir
|
||||
├── images
|
||||
│ ├── apple_pie
|
||||
│ │ ├── 1005649.jpg
|
||||
│ │ ├── 1014775.jpg
|
||||
│ │ ├──...
|
||||
│ ├── baby_back_rips
|
||||
│ │ ├── 1005293.jpg
|
||||
│ │ ├── 1007102.jpg
|
||||
│ │ ├──...
|
||||
│ └──...
|
||||
└── meta
|
||||
├── train.txt
|
||||
├── test.txt
|
||||
├── classes.txt
|
||||
├── train.json
|
||||
├── test.json
|
||||
└── train.txt
|
||||
|
||||
**引用:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
@inproceedings{bossard14,
|
||||
title = {Food-101 -- Mining Discriminative Components with Random Forests},
|
||||
author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
|
||||
booktitle = {European Conference on Computer Vision},
|
||||
year = {2014}
|
||||
}
|
||||
|
||||
|
||||
.. include:: mindspore.dataset.api_list_vision.rst
|
|
@ -112,6 +112,7 @@ mindspore.dataset
|
|||
mindspore.dataset.FashionMnistDataset
|
||||
mindspore.dataset.FlickrDataset
|
||||
mindspore.dataset.Flowers102Dataset
|
||||
mindspore.dataset.Food101Dataset
|
||||
mindspore.dataset.ImageFolderDataset
|
||||
mindspore.dataset.KMnistDataset
|
||||
mindspore.dataset.ManifestDataset
|
||||
|
|
|
@ -24,6 +24,7 @@ Vision
|
|||
mindspore.dataset.FashionMnistDataset
|
||||
mindspore.dataset.FlickrDataset
|
||||
mindspore.dataset.Flowers102Dataset
|
||||
mindspore.dataset.Food101Dataset
|
||||
mindspore.dataset.ImageFolderDataset
|
||||
mindspore.dataset.KMnistDataset
|
||||
mindspore.dataset.ManifestDataset
|
||||
|
|
|
@ -92,6 +92,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
|
||||
|
@ -1303,6 +1304,28 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
@ -326,6 +327,18 @@ PYBIND_REGISTER(FlickrNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Food101Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Food101Node, DatasetNode, std::shared_ptr<Food101Node>>(*m, "Food101Node",
|
||||
"to create a Food101Node")
|
||||
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const py::handle &sampler) {
|
||||
auto food101 =
|
||||
std::make_shared<Food101Node>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(food101->ValidateParams());
|
||||
return food101;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||
|
|
|
@ -22,6 +22,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
fake_image_op.cc
|
||||
fashion_mnist_op.cc
|
||||
flickr_op.cc
|
||||
food101_op.cc
|
||||
gtzan_op.cc
|
||||
image_folder_op.cc
|
||||
imdb_op.cc
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/food101_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <regex>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Food101Op::Food101Op(const std::string &folder_path, const std::string &usage, int32_t num_workers, int32_t queue_size,
|
||||
bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
folder_path_(folder_path),
|
||||
decode_(decode),
|
||||
usage_(usage),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
Status Food101Op::PrepareData() {
|
||||
auto realpath = FileUtils::GetRealPath(folder_path_.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Invalid file path, Food101 dataset dir: " << folder_path_ << " does not exist.";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file path, Food101 dataset dir: " + folder_path_ + " does not exist.");
|
||||
}
|
||||
|
||||
std::string image_root_path = (Path(realpath.value()) / Path("images")).ToString();
|
||||
std::string train_list_txt_ = (Path(realpath.value()) / Path("meta") / Path("train.txt")).ToString();
|
||||
std::string test_list_txt_ = (Path(realpath.value()) / Path("meta") / Path("test.txt")).ToString();
|
||||
|
||||
Path img_folder(image_root_path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(img_folder.Exists(),
|
||||
"Invalid path, Food101 image path: " + image_root_path + " does not exist.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(img_folder.IsDirectory(),
|
||||
"Invalid path, Food101 image path: " + image_root_path + " is not a folder.");
|
||||
std::shared_ptr<Path::DirIterator> img_folder_itr = Path::DirIterator::OpenDirectory(&img_folder);
|
||||
|
||||
RETURN_UNEXPECTED_IF_NULL(img_folder_itr);
|
||||
int32_t dirname_offset_ = img_folder.ToString().length() + 1;
|
||||
|
||||
while (img_folder_itr->HasNext()) {
|
||||
Path sub_dir = img_folder_itr->Next();
|
||||
if (sub_dir.IsDirectory()) {
|
||||
classes_.insert(sub_dir.ToString().substr(dirname_offset_));
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!classes_.empty(),
|
||||
"Invalid path, no subfolder found under path: " + img_folder.ToString());
|
||||
|
||||
if (usage_ == "test") {
|
||||
RETURN_IF_NOT_OK(GetAllImageList(test_list_txt_));
|
||||
} else if (usage_ == "train") {
|
||||
RETURN_IF_NOT_OK(GetAllImageList(train_list_txt_));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(GetAllImageList(train_list_txt_));
|
||||
RETURN_IF_NOT_OK(GetAllImageList(test_list_txt_));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!all_img_lists_.empty(),
|
||||
"No valid train.txt or test.txt file under path: " + image_root_path);
|
||||
all_img_lists_.shrink_to_fit();
|
||||
num_rows_ = all_img_lists_.size();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Food101Op::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\n";
|
||||
} else {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\nNumber of rows: " << num_rows_ << "\nFood101 dataset dir: " << folder_path_
|
||||
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status Food101Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::string img_name = (Path(folder_path_) / Path("images") / Path(all_img_lists_[row_id])).ToString() + ".jpg";
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
std::string label_str;
|
||||
for (auto it : classes_) {
|
||||
if (all_img_lists_[row_id].find(it) != all_img_lists_[row_id].npos) {
|
||||
label_str = it;
|
||||
break;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(label_str, &label));
|
||||
RETURN_IF_NOT_OK(ReadImageToTensor(img_name, &image));
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
trow->setPath({img_name, std::string("")});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Food101Op::ReadImageToTensor(const std::string &image_path, std::shared_ptr<Tensor> *tensor) {
|
||||
RETURN_UNEXPECTED_IF_NULL(tensor);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(image_path, tensor));
|
||||
if (decode_) {
|
||||
Status rc = Decode(*tensor, tensor);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
rc.IsOk(), "Invalid file, failed to decode image, the image may be broken or permission denied: " + image_path);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get dataset size.
|
||||
Status Food101Op::CountTotalRows(int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
if (all_img_lists_.size() == 0) {
|
||||
RETURN_IF_NOT_OK(PrepareData());
|
||||
}
|
||||
*count = static_cast<int64_t>(all_img_lists_.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Food101Op::GetAllImageList(const std::string &file_path) {
|
||||
std::ifstream handle(file_path);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open text:" + file_path +
|
||||
", the file is damaged or permission denied.");
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (getline(handle, line)) {
|
||||
if (!line.empty()) {
|
||||
all_img_lists_.push_back(line);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Food101Op::ComputeColMap() {
|
||||
// Set the column name map (base class field).
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
|
||||
column_name_id_map_[data_schema_->Column(index).Name()] = index;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class Food101Op : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] folder_path Directory of Food101 dataset.
|
||||
/// \param[in] usage Usage.
|
||||
/// \param[in] num_workers Number of workers reading images in parallel.
|
||||
/// \param[in] queue_size Connector queue size.
|
||||
/// \param[in] decode Whether to decode images.
|
||||
/// \param[in] schema Data schema of Food101 dataset.
|
||||
/// \param[in] sampler Sampler tells Food101 what to read.
|
||||
Food101Op(const std::string &folder_path, const std::string &usage, int32_t num_workers, int32_t queue_size,
|
||||
bool decode, std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Deconstructor.
|
||||
~Food101Op() override = default;
|
||||
|
||||
/// A print method typically used for debugging.
|
||||
/// \param[out] out Out stream.
|
||||
/// \param[in] show_all Whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "Food101Op"; }
|
||||
|
||||
/// Function to count the number of samples in the Food101 dataset.
|
||||
/// \param[in] count Output arg that will hold the actual dataset size.
|
||||
/// \return The status code returned.
|
||||
Status CountTotalRows(int64_t *count);
|
||||
|
||||
private:
|
||||
/// Load a tensor row.
|
||||
/// \param[in] row_id Row id.
|
||||
/// \param[in] row Read all features into this tensor row.
|
||||
/// \return The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
/// \param[in] image_path Path of image data.
|
||||
/// \param[in] tensor Get image tensor.
|
||||
/// \return The status code returned.
|
||||
Status ReadImageToTensor(const std::string &image_path, std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
/// Called first when function is called. Get file_name, img_path info from ".txt" files.
|
||||
/// \return Status - The status code returned.
|
||||
Status PrepareData();
|
||||
|
||||
/// Private function for computing the assignment of the column name map.
|
||||
/// \return Status-the status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
/// Private function for getting all the image files;
|
||||
/// \param[in] file_path The path of the dataset.
|
||||
/// \return Status-the status code returned.
|
||||
Status GetAllImageList(const std::string &file_path);
|
||||
|
||||
std::string folder_path_; // directory of Food101 folder.
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::set<std::string> classes_;
|
||||
std::vector<std::string> all_img_lists_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::map<std::string, std::vector<int32_t>> annotation_map_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FOOD101_OP_H_
|
|
@ -95,6 +95,7 @@ constexpr char kEnWik9Node[] = "EnWik9Dataset";
|
|||
constexpr char kFakeImageNode[] = "FakeImageDataset";
|
||||
constexpr char kFashionMnistNode[] = "FashionMnistDataset";
|
||||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kFood101Node[] = "Food101Dataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kGTZANNode[] = "GTZANDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
|
|
|
@ -23,6 +23,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
fake_image_node.cc
|
||||
fashion_mnist_node.cc
|
||||
flickr_node.cc
|
||||
food101_node.cc
|
||||
gtzan_node.cc
|
||||
image_folder_node.cc
|
||||
imdb_node.cc
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/food101_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Food101Node::Food101Node(const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> Food101Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<Food101Node>(dataset_dir_, usage_, decode_, sampler, cache_);
|
||||
(void)node->SetNumWorkers(num_workers_);
|
||||
(void)node->SetConnectorQueueSize(connector_que_size_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void Food101Node::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status Food101Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Food101Dataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Food101Dataset", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Food101Dataset", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Food101Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
RETURN_UNEXPECTED_IF_NULL(node_ops);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
auto op = std::make_shared<Food101Op>(dataset_dir_, usage_, num_workers_, connector_que_size_, decode_,
|
||||
std::move(schema), std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status Food101Node::GetShardId(int32_t *const shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status Food101Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dataset_size);
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows = 0, sample_size;
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(Build(&ops));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build Food101Op.");
|
||||
auto op = std::dynamic_pointer_cast<Food101Op>(ops.front());
|
||||
RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Food101Node::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["connector_queue_size"] = connector_que_size_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["decode"] = decode_;
|
||||
args["usage"] = usage_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status Food101Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFood101Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kFood101Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kFood101Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kFood101Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kFood101Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kFood101Node));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
bool decode = json_obj["decode"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
|
||||
std::shared_ptr<DatasetCache> cache = nullptr;
|
||||
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
|
||||
*ds = std::make_shared<Food101Node>(dataset_dir, usage, decode, sampler, cache);
|
||||
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
|
||||
(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class Food101Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Food101Node(const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Food101Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kFood101Node; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id The shard id of node.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *const shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size The size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to read dataset in json.
|
||||
/// \param[in] json_obj The JSON object to be deserialized.
|
||||
/// \param[out] ds Deserialized dataset.
|
||||
/// \return Status The status code returned.
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FOOD101_NODE_H_
|
|
@ -2763,6 +2763,99 @@ inline std::shared_ptr<FlickrDataset> DATASET_API Flickr(const std::string &data
|
|||
cache);
|
||||
}
|
||||
|
||||
/// \class Food101Dataset
|
||||
/// \brief A source dataset for reading and parsing Food101 dataset.
|
||||
class DATASET_API Food101Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of Food101Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of Food101Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of Food101Dataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of Food101Dataset.
|
||||
~Food101Dataset() override = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a Food101Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all". Default: "all".
|
||||
/// \param[in] decode Decode the images after reading. Default: false.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset. Default: RandomSampler().
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the Food101Dataset.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define dataset path and MindData object */
|
||||
/// std::string dataset_path = "/path/to/Food101_dataset_directory";
|
||||
/// std::shared_ptr<Dataset> ds = Food101(dataset_path);
|
||||
///
|
||||
/// /* Create iterator to read dataset */
|
||||
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In Food101 dataset, each data dictionary has keys "image" and "label" */
|
||||
/// auto image = row["image"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<Food101Dataset> DATASET_API
|
||||
Food101(const std::string &dataset_dir, const std::string &usage = "all", bool decode = false,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a Food101Dataset
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all" .
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the Food101Dataset.
|
||||
inline std::shared_ptr<Food101Dataset> DATASET_API Food101(const std::string &dataset_dir, const std::string &usage,
|
||||
bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a Food101Dataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "test" or "all".
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
|
||||
/// \return Shared pointer to the Food101Dataset.
|
||||
inline std::shared_ptr<Food101Dataset> DATASET_API Food101(const std::string &dataset_dir, const std::string &usage,
|
||||
bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<Food101Dataset>(StringToChar(dataset_dir), StringToChar(usage), decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class GTZANDataset
|
||||
/// \brief A source dataset for reading and parsing GTZAN dataset.
|
||||
class DATASET_API GTZANDataset : public Dataset {
|
||||
|
|
|
@ -47,6 +47,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class FakeImageDataset;
|
||||
friend class FashionMnistDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class Food101Dataset;
|
||||
friend class GTZANDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class IMDBDataset;
|
||||
|
|
|
@ -50,6 +50,7 @@ __all__ = ["Caltech101Dataset", # Vision
|
|||
"FashionMnistDataset", # Vision
|
||||
"FlickrDataset", # Vision
|
||||
"Flowers102Dataset", # Vision
|
||||
"Food101Dataset", # Vision
|
||||
"ImageFolderDataset", # Vision
|
||||
"KITTIDataset", # Vision
|
||||
"KMnistDataset", # Vision
|
||||
|
|
|
@ -32,14 +32,13 @@ import mindspore._c_dataengine as cde
|
|||
|
||||
from .datasets import VisionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema
|
||||
from .datasets_user_defined import GeneratorDataset
|
||||
from .validators import check_imagefolderdataset, check_kittidataset,\
|
||||
check_mnist_cifar_dataset, check_manifestdataset, check_vocdataset, check_cocodataset, \
|
||||
check_celebadataset, check_flickr_dataset, check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, \
|
||||
check_usps_dataset, check_div2k_dataset, check_random_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
|
||||
check_photo_tour_dataset, check_svhn_dataset, check_stl10_dataset, check_semeion_dataset, \
|
||||
check_caltech101_dataset, check_caltech256_dataset, check_wider_face_dataset, check_lfw_dataset, \
|
||||
check_lsun_dataset, check_omniglotdataset
|
||||
from .validators import check_caltech101_dataset, check_caltech256_dataset, check_celebadataset, \
|
||||
check_cityscapes_dataset, check_cocodataset, check_div2k_dataset, check_emnist_dataset, check_fake_image_dataset, \
|
||||
check_flickr_dataset, check_flowers102dataset, check_food101_dataset, check_imagefolderdataset, \
|
||||
check_kittidataset, check_lfw_dataset, check_lsun_dataset, check_manifestdataset, check_mnist_cifar_dataset, \
|
||||
check_omniglotdataset, check_photo_tour_dataset, check_places365_dataset, check_qmnist_dataset, \
|
||||
check_random_dataset, check_sb_dataset, check_sbu_dataset, check_semeion_dataset, check_stl10_dataset, \
|
||||
check_svhn_dataset, check_usps_dataset, check_vocdataset, check_wider_face_dataset
|
||||
|
||||
from ..core.validator_helpers import replace_none
|
||||
|
||||
|
@ -2190,6 +2189,140 @@ class Flowers102Dataset(GeneratorDataset):
|
|||
return class_dict
|
||||
|
||||
|
||||
class Food101Dataset(MappableDataset, VisionBaseDataset):
|
||||
"""
|
||||
A source dataset that reads and parses Food101 dataset.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[image, label]` .
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`label` is of the string type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be 'train', 'test', or 'all'. 'train' will read
|
||||
from 75,750 samples, 'test' will read from 25,250 samples, and 'all' will read all 'train'
|
||||
and 'test' samples. Default: None, will be set to 'all'.
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
Default: None, will read all images.
|
||||
num_parallel_workers (int, optional): Number of workers to read the data.
|
||||
Default: None, number set in the mindspore.dataset.config.
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
|
||||
Default: None, expected order behavior shown in the table below.
|
||||
decode (bool, optional): Decode the images after reading. Default: False.
|
||||
sampler (Sampler, optional): Object used to choose samples from the dataset.
|
||||
Default: None, expected order behavior shown in the table below.
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into. When this argument
|
||||
is specified, `num_samples` reflects the maximum sample number of per shard. Default: None.
|
||||
shard_id (int, optional): The shard ID within `num_shards` . This argument can only be specified
|
||||
when `num_shards` is also specified. Default: None.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
|
||||
`Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
|
||||
Default: None, which means no cache is used.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `dataset_dir` does not contain data files.
|
||||
RuntimeError: If `sampler` and `shuffle` are specified at the same time.
|
||||
RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
|
||||
RuntimeError: If `num_shards` is specified but `shard_id` is None.
|
||||
RuntimeError: If `shard_id` is specified but `num_shards` is None.
|
||||
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
|
||||
ValueError: If `num_parallel_workers` exceeds the max thread numbers.
|
||||
ValueError: If the value of `usage` is not 'train', 'test', or 'all'.
|
||||
ValueError: If `dataset_dir` is not exist.
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler` . `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> food101_dataset_dir = "/path/to/food101_dataset_directory"
|
||||
>>>
|
||||
>>> # Read 3 samples from Food101 dataset
|
||||
>>> dataset = ds.Food101Dataset(dataset_dir=food101_dataset_dir, num_samples=3)
|
||||
|
||||
About Food101 dataset:
|
||||
|
||||
The Food101 is a challenging dataset of 101 food categories, with 101,000 images.
|
||||
There are 250 test imgaes and 750 training images in each class. All images were rescaled
|
||||
to have a maximum side length of 512 pixels.
|
||||
|
||||
The following is the original Food101 dataset structure.
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── food101_dir
|
||||
├── images
|
||||
│ ├── apple_pie
|
||||
│ │ ├── 1005649.jpg
|
||||
│ │ ├── 1014775.jpg
|
||||
│ │ ├──...
|
||||
│ ├── baby_back_rips
|
||||
│ │ ├── 1005293.jpg
|
||||
│ │ ├── 1007102.jpg
|
||||
│ │ ├──...
|
||||
│ └──...
|
||||
└── meta
|
||||
├── train.txt
|
||||
├── test.txt
|
||||
├── classes.txt
|
||||
├── train.json
|
||||
├── test.json
|
||||
└── train.txt
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@inproceedings{bossard14,
|
||||
title = {Food-101 -- Mining Discriminative Components with Random Forests},
|
||||
author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
|
||||
booktitle = {European Conference on Computer Vision},
|
||||
year = {2014}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_food101_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
|
||||
decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
self.decode = replace_none(decode, False)
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.Food101Node(self.dataset_dir, self.usage, self.decode, self.sampler)
|
||||
|
||||
|
||||
class ImageFolderDataset(MappableDataset, VisionBaseDataset):
|
||||
"""
|
||||
A source dataset that reads images from a tree of directories.
|
||||
|
|
|
@ -2328,6 +2328,36 @@ def check_flickr_dataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_food101_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the Food101Dataset."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['decode', 'shuffle']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test", "all"], "usage")
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_sb_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""
|
||||
|
||||
|
|
|
@ -0,0 +1,245 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test basic usage of Food101Dataset
|
||||
/// Expectation: The data is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestFood101TestDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDataset.";
|
||||
|
||||
// Create a Food101 Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "test", true, std::make_shared<RandomSampler>(false, 4));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 4);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset in pipeline mode
|
||||
/// Expectation: The data is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestFood101TestDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDatasetWithPipeline.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
|
||||
// Create two Food101 Test Dataset
|
||||
std::shared_ptr<Dataset> ds1 = Food101(folder_path, "test", true, std::make_shared<RandomSampler>(false, 4));
|
||||
std::shared_ptr<Dataset> ds2 = Food101(folder_path, "test", true,std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 1;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 1;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 8);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset GetDatasetSize
|
||||
/// Expectation: Correct size of dataset
|
||||
TEST_F(MindDataTestPipeline, TestGetFood101TestDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetFood101TestDatasetSize.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
|
||||
// Create a Food101 Test Dataset
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "all");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 8);
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset Getters method
|
||||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestPipeline, TestFood101TestDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101TestDatasetGetters.";
|
||||
|
||||
// Create a Food101 Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "test");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image", "label"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "string");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 4);
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test iterator of Food101Dataset with wrong column
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestFood101IteratorWrongColumn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101IteratorOneColumn.";
|
||||
// Create a Food101 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "all", false, std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Pass wrong column name
|
||||
std::vector<std::string> columns = {"digital"};
|
||||
std::shared_ptr<ProjectDataset> project_ds = ds->Project(columns);
|
||||
std::shared_ptr<Iterator> iter = project_ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset with empty string as the folder path
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestFood101DatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101DatasetFail.";
|
||||
|
||||
// Create a Food101 Dataset
|
||||
std::shared_ptr<Dataset> ds = Food101("", "train", false, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Food101 input
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset with invalid usage
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestFood101DatasetWithInvalidUsageFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101DatasetWithInvalidUsageFail.";
|
||||
|
||||
// Create a Food101 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "validation");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Food101 input, validation is not a valid usage
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Food101Dataset
|
||||
/// Description: Test Food101Dataset with null sampler
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestFood101DatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFood101UDatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a Food101 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testFood101Data/";
|
||||
std::shared_ptr<Dataset> ds = Food101(folder_path, "all", false, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Food101 input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
After Width: | Height: | Size: 155 KiB |
|
@ -0,0 +1,4 @@
|
|||
class1/0
|
||||
class2/1
|
||||
class3/0
|
||||
class4/0
|
|
@ -0,0 +1,4 @@
|
|||
class1/1
|
||||
class2/0
|
||||
class3/1
|
||||
class4/1
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test Food101 dataset operators
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testFood101Data/"
|
||||
|
||||
|
||||
def test_food101_basic():
|
||||
"""
|
||||
Feature: Food101 dataset
|
||||
Description: Read all files
|
||||
Expectation: Throw number of data in all files
|
||||
"""
|
||||
logger.info("Test Food101Dataset Op")
|
||||
|
||||
# case 1: test loading default usage dataset
|
||||
data1 = ds.Food101Dataset(DATA_DIR)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 8
|
||||
|
||||
# case 2: test num_samples
|
||||
data2 = ds.Food101Dataset(DATA_DIR, num_samples=1)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 1
|
||||
|
||||
# case 3: test repeat
|
||||
data3 = ds.Food101Dataset(DATA_DIR, num_samples=2)
|
||||
data3 = data3.repeat(5)
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 10
|
||||
|
||||
|
||||
def test_food101_noshuffle():
|
||||
"""
|
||||
Feature: Food101 dataset
|
||||
Description: Test noshuffle
|
||||
Expectation: Throw number of data in all files
|
||||
"""
|
||||
logger.info("Test Case noShuffle")
|
||||
# define parameters
|
||||
repeat_count = 1
|
||||
|
||||
# apply dataset operations
|
||||
# Note: "all" reads both "train" dataset (2 samples) and "valid" dataset (2 samples)
|
||||
data1 = ds.Food101Dataset(DATA_DIR, shuffle=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
num_iter = 0
|
||||
# each data is a dictionary
|
||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 8
|
||||
|
||||
|
||||
def test_food101_usage():
|
||||
"""
|
||||
Feature: Food101 dataset
|
||||
Description: Test Usage
|
||||
Expectation: Throw number of data in all files
|
||||
"""
|
||||
logger.info("Test Food101Dataset usage flag")
|
||||
|
||||
def test_config(usage, food101_path=DATA_DIR):
|
||||
try:
|
||||
data = ds.Food101Dataset(food101_path, usage=usage)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
# test the usage of Food101
|
||||
assert test_config("test") == 4
|
||||
assert test_config("train") == 4
|
||||
assert test_config("all") == 8
|
||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config(
|
||||
"invalid")
|
||||
|
||||
# change to the folder that contains all Food101 files
|
||||
all_food101 = None
|
||||
if all_food101 is not None:
|
||||
assert test_config("test", all_food101) == 25250
|
||||
assert test_config("train", all_food101) == 75750
|
||||
assert test_config("all", all_food101) == 101000
|
||||
assert ds.Food101Dataset(all_food101, usage="test").get_dataset_size() == 25250
|
||||
assert ds.Food101Dataset(all_food101, usage="train").get_dataset_size() == 75750
|
||||
assert ds.Food101Dataset(all_food101, usage="all").get_dataset_size() == 101000
|
||||
|
||||
|
||||
def test_food101_sequential_sampler():
|
||||
"""
|
||||
Feature: Food101 dataset
|
||||
Description: Test SequentialSampler
|
||||
Expectation: Get correct number of data
|
||||
"""
|
||||
num_samples = 1
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.Food101Dataset(DATA_DIR, 'test', sampler=sampler)
|
||||
data2 = ds.Food101Dataset(DATA_DIR, 'test', shuffle=False, num_samples=num_samples)
|
||||
matches_list1, matches_list2 = [], []
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
|
||||
matches_list1.append(item1["image"].asnumpy())
|
||||
matches_list2.append(item2["image"].asnumpy())
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(matches_list1, matches_list2)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_food101_pipeline():
|
||||
"""
|
||||
Feature: Pipeline test
|
||||
Description: Read a sample
|
||||
Expectation: The amount of each function are equal
|
||||
"""
|
||||
dataset = ds.Food101Dataset(DATA_DIR, "test", num_samples=1, decode=True)
|
||||
resize_op = vision.Resize((100, 100))
|
||||
dataset = dataset.map(input_columns=["image"], operations=resize_op)
|
||||
num_iter = 0
|
||||
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
assert num_iter == 1
|
||||
|
||||
|
||||
def test_food101_exception():
|
||||
"""
|
||||
Feature: Food101 dataset
|
||||
Description: Throw error messages when certain errors occur
|
||||
Expectation: Error message
|
||||
"""
|
||||
logger.info("Test error cases for Food101Dataset")
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.Food101Dataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.Food101Dataset(DATA_DIR, sampler=ds.PKSampler(
|
||||
3), num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.Food101Dataset(DATA_DIR, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.Food101Dataset(DATA_DIR, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.Food101Dataset(DATA_DIR, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.Food101Dataset(DATA_DIR, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.Food101Dataset(DATA_DIR, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.Food101Dataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.Food101Dataset(DATA_DIR, num_shards=2, shard_id="0")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_food101_basic()
|
||||
test_food101_sequential_sampler()
|
||||
test_food101_noshuffle()
|
||||
test_food101_usage()
|
||||
test_food101_pipeline()
|
||||
test_food101_exception()
|