[feat] [assistant] [I3CKEK] add new dataset operator DIV2K
|
@ -95,6 +95,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
@ -982,6 +983,33 @@ CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
|
||||
decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
|
||||
decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
|
||||
decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
@ -137,6 +138,18 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
|
||||
"to create a DIV2KNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale,
|
||||
bool decode, py::handle sampler) {
|
||||
auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
|
||||
toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(div2k->ValidateParams());
|
||||
return div2k;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FlickrNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
|
||||
|
|
|
@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
mappable_leaf_op.cc
|
||||
nonmappable_leaf_op.cc
|
||||
cityscapes_op.cc
|
||||
div2k_op.cc
|
||||
flickr_op.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,285 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/div2k_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
#include "debug/common.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const std::map<std::string, std::string> DatasetPramMap = {{"train_hr", "DIV2K_train_HR"},
|
||||
{"valid_hr", "DIV2K_valid_HR"},
|
||||
{"train_bicubic_x2", "DIV2K_train_LR_bicubic"},
|
||||
{"train_unknown_x2", "DIV2K_train_LR_unknown"},
|
||||
{"valid_bicubic_x2", "DIV2K_valid_LR_bicubic"},
|
||||
{"valid_unknown_x2", "DIV2K_valid_LR_unknown"},
|
||||
{"train_bicubic_x3", "DIV2K_train_LR_bicubic"},
|
||||
{"train_unknown_x3", "DIV2K_train_LR_unknown"},
|
||||
{"valid_bicubic_x3", "DIV2K_valid_LR_bicubic"},
|
||||
{"valid_unknown_x3", "DIV2K_valid_LR_unknown"},
|
||||
{"train_bicubic_x4", "DIV2K_train_LR_bicubic"},
|
||||
{"train_unknown_x4", "DIV2K_train_LR_unknown"},
|
||||
{"valid_bicubic_x4", "DIV2K_valid_LR_bicubic"},
|
||||
{"valid_unknown_x4", "DIV2K_valid_LR_unknown"},
|
||||
{"train_bicubic_x8", "DIV2K_train_LR_x8"},
|
||||
{"valid_bicubic_x8", "DIV2K_valid_LR_x8"},
|
||||
{"train_mild_x4", "DIV2K_train_LR_mild"},
|
||||
{"valid_mild_x4", "DIV2K_valid_LR_mild"},
|
||||
{"train_difficult_x4", "DIV2K_train_LR_difficult"},
|
||||
{"valid_difficult_x4", "DIV2K_valid_LR_difficult"},
|
||||
{"train_wild_x4", "DIV2K_train_LR_wild"},
|
||||
{"valid_wild_x4", "DIV2K_valid_LR_wild"}};
|
||||
|
||||
DIV2KOp::DIV2KOp(int32_t num_workers, const std::string &dataset_dir, const std::string &usage,
|
||||
const std::string &downgrade, int32_t scale, bool decode, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
downgrade_(downgrade),
|
||||
scale_(scale),
|
||||
decode_(decode),
|
||||
data_schema_(std::move(data_schema)) {
|
||||
io_block_queues_.Init(num_workers_, queue_size);
|
||||
}
|
||||
|
||||
Status DIV2KOp::LaunchThreadsAndInitOp() {
|
||||
if (tree_ == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&DIV2KOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
TaskManager::FindMe()->Post();
|
||||
// The order of the following 3 functions must not be changed!
|
||||
RETURN_IF_NOT_OK(ParseDIV2KData()); // Parse div2k data and get num rows, blocking
|
||||
RETURN_IF_NOT_OK(CountDatasetInfo()); // Count the total rows
|
||||
RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Load 1 TensorRow (hr_image, lr_image) using 1 ImageLabelPair. 1 function call produces 1 TensorTow.
|
||||
Status DIV2KOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::pair<std::string, std::string> data = image_hr_lr_pairs_[static_cast<size_t>(row_id)];
|
||||
std::shared_ptr<Tensor> hr_image;
|
||||
std::shared_ptr<Tensor> lr_image;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.first, &hr_image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.second, &lr_image));
|
||||
|
||||
if (decode_ == true) {
|
||||
Status hr_rc = Decode(hr_image, &hr_image);
|
||||
if (hr_rc.IsError()) {
|
||||
std::string err = "Invalid data, failed to decode image: " + data.first;
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
|
||||
Status lr_rc = Decode(lr_image, &lr_image);
|
||||
if (lr_rc.IsError()) {
|
||||
std::string err = "Invalid data, failed to decode image: " + data.second;
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
(*trow) = TensorRow(row_id, {std::move(hr_image), std::move(lr_image)});
|
||||
trow->setPath({data.first, data.second});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DIV2KOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows:" << num_rows_ << "\nDIV2K DatasetDir: " << dataset_dir_ << "\nUsage: " << usage_
|
||||
<< "\nScale: " << scale_ << "\nDowngrade: " << downgrade_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status DIV2KOp::ParseDIV2KData() {
|
||||
std::string hr_dir_key;
|
||||
std::string lr_dir_key;
|
||||
|
||||
if (usage_ == "all") {
|
||||
std::vector<std::string> usage_all = {"train", "valid"};
|
||||
for (auto &item : usage_all) {
|
||||
hr_dir_key = item + "_hr";
|
||||
lr_dir_key = item + "_" + downgrade_ + "_x" + std::to_string(scale_);
|
||||
RETURN_IF_NOT_OK(GetDIV2KLRDirRealName(hr_dir_key, lr_dir_key));
|
||||
RETURN_IF_NOT_OK(GetDIV2KDataByUsage());
|
||||
}
|
||||
} else {
|
||||
hr_dir_key = usage_ + "_hr";
|
||||
lr_dir_key = usage_ + "_" + downgrade_ + "_x" + std::to_string(scale_);
|
||||
RETURN_IF_NOT_OK(GetDIV2KLRDirRealName(hr_dir_key, lr_dir_key));
|
||||
RETURN_IF_NOT_OK(GetDIV2KDataByUsage());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KOp::GetDIV2KLRDirRealName(const std::string &hr_dir_key, const std::string &lr_dir_key) {
|
||||
std::set<std::string> downgrade_2017 = {"bicubic", "unknown"};
|
||||
std::set<int32_t> scale_2017 = {2, 3, 4};
|
||||
|
||||
hr_dir_real_name_ = DatasetPramMap.find(hr_dir_key)->second;
|
||||
auto lr_it = DatasetPramMap.find(lr_dir_key);
|
||||
if (lr_it == DatasetPramMap.end()) {
|
||||
std::string out_str = "{\n";
|
||||
std::for_each(DatasetPramMap.begin(), DatasetPramMap.end(),
|
||||
[&out_str](std::pair<std::string, std::string> item) -> void {
|
||||
out_str += ("\t" + item.first + ": " + item.second + ",\n");
|
||||
});
|
||||
out_str += "\n}";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid param, " + lr_dir_key + " not found in DatasetPramMap: \n" + out_str);
|
||||
}
|
||||
|
||||
if (downgrade_2017.find(downgrade_) != downgrade_2017.end() && scale_2017.find(scale_) != scale_2017.end()) {
|
||||
Path ntire_2017(lr_it->second);
|
||||
lr_dir_real_name_ = (ntire_2017 / ("X" + std::to_string(scale_))).ToString();
|
||||
} else {
|
||||
lr_dir_real_name_ = lr_it->second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KOp::GetDIV2KDataByUsage() {
|
||||
const std::string kExtension = ".png";
|
||||
|
||||
auto real_dataset_dir = Common::GetRealPath(dataset_dir_);
|
||||
if (!real_dataset_dir.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_);
|
||||
}
|
||||
|
||||
Path dataset_dir(real_dataset_dir.value());
|
||||
Path hr_images_dir = dataset_dir / hr_dir_real_name_;
|
||||
Path lr_images_dir = dataset_dir / lr_dir_real_name_;
|
||||
|
||||
if (!hr_images_dir.IsDirectory()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, " + hr_images_dir.ToString() + " is an invalid directory path.");
|
||||
}
|
||||
if (!lr_images_dir.IsDirectory()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, " + lr_images_dir.ToString() + " is an invalid directory path.");
|
||||
}
|
||||
auto hr_it = Path::DirIterator::OpenDirectory(&hr_images_dir);
|
||||
if (hr_it == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, failed to open directory: " + hr_images_dir.ToString());
|
||||
}
|
||||
|
||||
std::string image_name;
|
||||
std::string image_id_scale;
|
||||
std::string lr_image_file_path_;
|
||||
std::map<std::string, std::string> image_hr_lr_map_;
|
||||
std::map<std::string, std::string> downgrade_2018 = {{"mild", "m"}, {"difficult", "d"}, {"wild", "w"}};
|
||||
|
||||
while (hr_it->HasNext()) {
|
||||
try {
|
||||
Path hr_img_file = hr_it->Next();
|
||||
if (hr_img_file.Extension() != kExtension) {
|
||||
continue;
|
||||
}
|
||||
|
||||
image_name = hr_img_file.Basename();
|
||||
image_id_scale = image_name.substr(0, image_name.find_last_of(".")) + "x" + std::to_string(scale_);
|
||||
Path hr_image_file_path = hr_images_dir / image_name;
|
||||
auto lr_it = downgrade_2018.find(downgrade_);
|
||||
if (lr_it != downgrade_2018.end()) {
|
||||
lr_image_file_path_ = (lr_images_dir / (image_id_scale + lr_it->second + kExtension)).ToString();
|
||||
} else {
|
||||
lr_image_file_path_ = (lr_images_dir / (image_id_scale + kExtension)).ToString();
|
||||
}
|
||||
|
||||
Path lr_image_file_path(lr_image_file_path_);
|
||||
if (!lr_image_file_path.Exists()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, " + lr_image_file_path.ToString() + " not found.");
|
||||
}
|
||||
|
||||
image_hr_lr_map_[hr_image_file_path.ToString()] = lr_image_file_path.ToString();
|
||||
} catch (const std::exception &err) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid path, failed to load DIV2K Dataset: " + dataset_dir_);
|
||||
}
|
||||
}
|
||||
for (auto item : image_hr_lr_map_) {
|
||||
image_hr_lr_pairs_.emplace_back(std::make_pair(item.first, item.second));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KOp::CountDatasetInfo() {
|
||||
num_rows_ = static_cast<int64_t>(image_hr_lr_pairs_.size());
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, no valid data matching the dataset API DIV2KDataset. Please check file path or dataset API.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KOp::CountTotalRows(const std::string &dir, const std::string &usage, const std::string &downgrade,
|
||||
int32_t scale, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseDIV2KData()
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto new_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
|
||||
// build a new unique schema object
|
||||
auto new_schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(new_schema->AddColumn(ColDescriptor("hr_image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
new_schema->AddColumn(ColDescriptor("lr_image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
std::shared_ptr<DIV2KOp> op = std::make_shared<DIV2KOp>(
|
||||
num_workers, dir, usage, downgrade, scale, false, op_connect_size, std::move(new_schema), std::move(new_sampler));
|
||||
RETURN_IF_NOT_OK(op->ParseDIV2KData());
|
||||
*count = static_cast<int64_t>(op->image_hr_lr_pairs_.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KOp::ComputeColMap() {
|
||||
// Set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DIV2KOp : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] int32_t num_workers - num of workers reading images in parallel.
|
||||
/// \param[in] std::string dataset_dir - dir directory of DIV2K dataset.
|
||||
/// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] std::string downgrade - the mode of downgrade. Acceptable downgrades include "bicubic", "unknown",
|
||||
/// "mild", "difficult" or "wild".
|
||||
/// \param[in] int32_t scale - the scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] bool decode - decode the images after reading.
|
||||
/// \param[in] int32_t queue_size - connector queue size.
|
||||
/// \param[in] DataSchema data_schema - the schema of each column in output data.
|
||||
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read.
|
||||
DIV2KOp(int32_t num_workers, const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
|
||||
int32_t scale, bool decode, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DIV2KOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[out] out
|
||||
/// \param[in] show_all
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Function to count the number of samples in the DIV2K dataset.
|
||||
/// \param[in] dir - path to the DIV2K directory.
|
||||
/// \param[in] usage - the type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade - the mode of downgrade. Acceptable downgrades include "bicubic", "unknown",
|
||||
/// "mild", "difficult" or "wild".
|
||||
/// \param[in] scale - the scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[out] count - output arg that will hold the actual dataset size.
|
||||
/// \return Status - The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, const std::string &downgrade,
|
||||
int32_t scale, int64_t *count);
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "DIV2KOp"; }
|
||||
|
||||
private:
|
||||
/// \brief Load a tensor row according to a pair.
|
||||
/// \param[in] uint64_t index - index need to load.
|
||||
/// \param[out] TensorRow row - image & label read into this tensor row.
|
||||
/// \return Status - The status code returned.
|
||||
Status LoadTensorRow(row_id_type index, TensorRow *trow) override;
|
||||
|
||||
/// \brief Called first when function is called.
|
||||
/// \return Status - The status code returned.
|
||||
Status LaunchThreadsAndInitOp() override;
|
||||
|
||||
/// \brief Get the real name of high resolution images and low resolution images dir in DIV2K dataset.
|
||||
/// \param[in] hr_dir_key - the key of high resolution images dir.
|
||||
/// \param[in] lr_dir_key - the key of high resolution images dir.
|
||||
/// \return Status - The status code returned.
|
||||
Status GetDIV2KLRDirRealName(const std::string &hr_dir_key, const std::string &lr_dir_key);
|
||||
|
||||
/// \brief Parse DIV2K data.
|
||||
/// \return Status - The status code returned.
|
||||
Status ParseDIV2KData();
|
||||
|
||||
/// \brief Get DIV2K data by usage.
|
||||
/// \return Status - The status code returned.
|
||||
Status GetDIV2KDataByUsage();
|
||||
|
||||
/// \brief Count label index,num rows and num samples.
|
||||
/// \return Status - The status code returned.
|
||||
Status CountDatasetInfo();
|
||||
|
||||
/// \brief Private function for computing the assignment of the column name map.
|
||||
/// \return Status - The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
int32_t scale_;
|
||||
std::string downgrade_;
|
||||
bool decode_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> image_hr_lr_pairs_;
|
||||
std::string hr_dir_real_name_;
|
||||
std::string lr_dir_real_name_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_
|
|
@ -82,6 +82,7 @@ constexpr char kCityscapesNode[] = "CityscapesDataset";
|
|||
constexpr char kCLUENode[] = "CLUEDataset";
|
||||
constexpr char kCocoNode[] = "CocoDataset";
|
||||
constexpr char kCSVNode[] = "CSVDataset";
|
||||
constexpr char kDIV2KNode[] = "DIV2KDataset";
|
||||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
|
|
|
@ -11,6 +11,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
clue_node.cc
|
||||
coco_node.cc
|
||||
csv_node.cc
|
||||
div2k_node.cc
|
||||
flickr_node.cc
|
||||
image_folder_node.cc
|
||||
manifest_node.cc
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/div2k_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for DIV2KNode
|
||||
DIV2KNode::DIV2KNode(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
|
||||
int32_t scale, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
downgrade_(downgrade),
|
||||
scale_(scale),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> DIV2KNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<DIV2KNode>(dataset_dir_, usage_, downgrade_, scale_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void DIV2KNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(dataset dir:" + dataset_dir_;
|
||||
out << ", usage:" + usage_ << ", scale:" + std::to_string(scale_) << ", downgrade:" + downgrade_;
|
||||
if (sampler_ != nullptr) {
|
||||
out << ", sampler";
|
||||
}
|
||||
if (cache_ != nullptr) {
|
||||
out << ", cache";
|
||||
}
|
||||
out << ")";
|
||||
}
|
||||
|
||||
Status DIV2KNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("DIV2KNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("DIV2KNode", usage_, {"train", "valid", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("DIV2KNode", downgrade_, {"bicubic", "unknown", "mild", "difficult", "wild"}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("DIV2KNode", sampler_));
|
||||
|
||||
std::set<int32_t> scale_arr = {2, 3, 4, 8};
|
||||
auto s_it = scale_arr.find(scale_);
|
||||
if (s_it == scale_arr.end()) {
|
||||
std::string err_msg = "DIV2KNode: " + std::to_string(scale_) + " does not match any mode in [2, 3, 4, 8].";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (scale_ == 8 && downgrade_ != "bicubic") {
|
||||
std::string err_msg = "DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
std::set<std::string> downgrade_2018 = {"mild", "difficult", "wild"};
|
||||
auto it = downgrade_2018.find(downgrade_);
|
||||
if (it != downgrade_2018.end() && scale_ != 4) {
|
||||
std::string err_msg = "DIV2KNode: " + downgrade_ + " downgrade requires scale equal to 4.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build DIV2KOp for DIV2K
|
||||
Status DIV2KNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("hr_image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("lr_image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto div2k_op = std::make_shared<DIV2KOp>(num_workers_, dataset_dir_, usage_, downgrade_, scale_, decode_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_rt));
|
||||
div2k_op->set_total_repeats(GetTotalRepeats());
|
||||
div2k_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(div2k_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status DIV2KNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status DIV2KNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(DIV2KOp::CountTotalRows(dataset_dir_, usage_, downgrade_, scale_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DIV2KNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
args["downgrade"] = downgrade_;
|
||||
args["scale"] = scale_;
|
||||
args["decode"] = decode_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DIV2KNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
DIV2KNode(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade, int32_t scale,
|
||||
bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DIV2KNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kDIV2KNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] out - The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[out] node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
const int32_t &Scale() const { return scale_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &Downgrade() const { return downgrade_; }
|
||||
|
||||
/// \brief Getter functions
|
||||
bool Decode() const { return decode_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
int32_t scale_;
|
||||
std::string downgrade_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_
|
|
@ -1524,6 +1524,112 @@ inline std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_f
|
|||
cache);
|
||||
}
|
||||
|
||||
/// \class DIV2KDataset
|
||||
/// \brief A source dataset for reading and parsing DIV2K dataset.
|
||||
class DIV2KDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of DIV2KDataset.
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
|
||||
/// "difficult" or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of DIV2KDataset.
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
|
||||
/// "difficult" or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of DIV2KDataset.
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
|
||||
/// "difficult" or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of DIV2KDataset.
|
||||
~DIV2KDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a DIV2KDataset.
|
||||
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
|
||||
/// or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current DIV2KDataset.
|
||||
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::string &downgrade, int32_t scale, bool decode = false,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
|
||||
decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a DIV2KDataset.
|
||||
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
|
||||
/// or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current DIV2KDataset.
|
||||
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::string &downgrade, int32_t scale, bool decode,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
|
||||
decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a DIV2KDataset.
|
||||
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
|
||||
/// \param[in] dataset_dir The dataset dir to be read.
|
||||
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
|
||||
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
|
||||
/// or "wild".
|
||||
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current DIV2KDataset.
|
||||
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::string &downgrade, int32_t scale, bool decode,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
|
||||
decode, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class FlickrDataset
|
||||
/// \brief A source dataset for reading and parsing Flickr dataset.
|
||||
class FlickrDataset : public Dataset {
|
||||
|
|
|
@ -38,6 +38,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class CLUEDataset;
|
||||
friend class CocoDataset;
|
||||
friend class CSVDataset;
|
||||
friend class DIV2KDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class ManifestDataset;
|
||||
|
|
|
@ -64,7 +64,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
|
||||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -6643,3 +6643,190 @@ class CityscapesDataset(MappableDataset):
|
|||
|
||||
def parse(self, children=None):
|
||||
return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler)
|
||||
|
||||
|
||||
class DIV2KDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing DIV2KDataset dataset.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[hr_image, lr_image]`.
|
||||
The tensor of column :py:obj:`hr_image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`lr_image` is of the uint8 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str): Acceptable usages include `train`, `valid` or `all` (default=`train`).
|
||||
downgrade (str): Acceptable downgrades include `bicubic`, `unknown`, `mild`, `difficult` or
|
||||
`wild` (default=`bicubic`).
|
||||
scale (int): Acceptable scales include 2, 3, 4 or 8 (default=2).
|
||||
When `downgrade` is `bicubic`, scale can be 2, 3, 4, 8.
|
||||
When `downgrade` is `unknown`, scale can only be 2, 3, 4.
|
||||
When `downgrade` is `mild`, `difficult` or `wild`, scale can only be 4.
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
|
||||
order behavior shown in the table).
|
||||
decode (bool, optional): Decode the images after reading (default=False).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into (default=None). When this argument is specified, `num_samples` reflects
|
||||
the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir is invalid or does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If dataset_dir is not exist.
|
||||
ValueError: If usage is invalid.
|
||||
ValueError: If downgrade is invalid.
|
||||
ValueError: If scale is invalid.
|
||||
ValueError: If scale equal to 8 and downgrade not equal to `bicubic`.
|
||||
ValueError: If downgrade in [`mild`, `difficult`, `wild`] and scale not equal to 4.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> div2k_dataset_dir = "/path/to/div2k_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) Get all samples from DIV2K dataset in sequence
|
||||
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
|
||||
>>> shuffle=False)
|
||||
>>>
|
||||
>>> # 2) Randomly select 350 samples from DIV2K dataset
|
||||
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
|
||||
>>> num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 3) Get samples from DIV2K dataset for shard 0 in a 2-way distributed training
|
||||
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
|
||||
>>> num_shards=2, shard_id=0)
|
||||
>>>
|
||||
>>> # In DIV2K dataset, each dictionary has keys "hr_image" and "lr_image"
|
||||
|
||||
About DIV2K dataset:
|
||||
|
||||
The DIV2K dataset consists of 1000 2K resolution images, among which 800 images are for training, 100 images
|
||||
are for validation and 100 images are for testing. NTIRE 2017 and NTIRE 2018 include only training dataset
|
||||
and validation dataset.
|
||||
|
||||
You can unzip the dataset files into the following directory structure and read by MindSpore's API.
|
||||
|
||||
Take the training set as an example.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── DIV2K
|
||||
├── DIV2K_train_HR
|
||||
| ├── 0001.png
|
||||
| ├── 0002.png
|
||||
| ├── ...
|
||||
├── DIV2K_train_LR_bicubic
|
||||
| ├── X2
|
||||
| | ├── 0001x2.png
|
||||
| | ├── 0002x2.png
|
||||
| | ├── ...
|
||||
| ├── X3
|
||||
| | ├── 0001x3.png
|
||||
| | ├── 0002x3.png
|
||||
| | ├── ...
|
||||
| └── X4
|
||||
| ├── 0001x4.png
|
||||
| ├── 0002x4.png
|
||||
| ├── ...
|
||||
├── DIV2K_train_LR_unknown
|
||||
| ├── X2
|
||||
| | ├── 0001x2.png
|
||||
| | ├── 0002x2.png
|
||||
| | ├── ...
|
||||
| ├── X3
|
||||
| | ├── 0001x3.png
|
||||
| | ├── 0002x3.png
|
||||
| | ├── ...
|
||||
| └── X4
|
||||
| ├── 0001x4.png
|
||||
| ├── 0002x4.png
|
||||
| ├── ...
|
||||
├── DIV2K_train_LR_mild
|
||||
| ├── 0001x4m.png
|
||||
| ├── 0002x4m.png
|
||||
| ├── ...
|
||||
├── DIV2K_train_LR_difficult
|
||||
| ├── 0001x4d.png
|
||||
| ├── 0002x4d.png
|
||||
| ├── ...
|
||||
├── DIV2K_train_LR_wild
|
||||
| ├── 0001x4w.png
|
||||
| ├── 0002x4w.png
|
||||
| ├── ...
|
||||
└── DIV2K_train_LR_x8
|
||||
├── 0001x8.png
|
||||
├── 0002x8.png
|
||||
├── ...
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@InProceedings{Agustsson_2017_CVPR_Workshops,
|
||||
author = {Agustsson, Eirikur and Timofte, Radu},
|
||||
title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
|
||||
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
|
||||
url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf",
|
||||
month = {July},
|
||||
year = {2017}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_div2k_dataset
|
||||
def __init__(self, dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None,
|
||||
shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = usage
|
||||
self.scale = scale
|
||||
self.downgrade = downgrade
|
||||
self.decode = replace_none(decode, False)
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler)
|
||||
|
|
|
@ -1489,3 +1489,45 @@ def check_cityscapes_dataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_div2k_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original DIV2KDataset."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle', 'decode']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
check_valid_str(usage, ['train', 'valid', 'all'], "usage")
|
||||
|
||||
downgrade = param_dict.get('downgrade')
|
||||
check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')
|
||||
|
||||
validate_dataset_param_value(['scale'], param_dict, int)
|
||||
scale = param_dict.get('scale')
|
||||
scale_values = [2, 3, 4, 8]
|
||||
if scale not in scale_values:
|
||||
raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))
|
||||
|
||||
if scale == 8 and downgrade != "bicubic":
|
||||
raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")
|
||||
|
||||
downgrade_2018 = ["mild", "difficult", "wild"]
|
||||
if downgrade in downgrade_2018 and scale != 4:
|
||||
raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -22,6 +22,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_coco_test.cc
|
||||
c_api_dataset_config_test.cc
|
||||
c_api_dataset_csv_test.cc
|
||||
c_api_dataset_div2k_test.cc
|
||||
c_api_dataset_flickr_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
|
|
|
@ -0,0 +1,305 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KBasic.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create a DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto hr_image = row["hr_image"];
|
||||
auto lr_image = row["lr_image"];
|
||||
MS_LOG(INFO) << "Tensor hr_image shape: " << hr_image.Shape();
|
||||
MS_LOG(INFO) << "Tensor lr_image shape: " << lr_image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KBasicWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KBasicWithPipeline.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create two DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds1 =
|
||||
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 2));
|
||||
std::shared_ptr<Dataset> ds2 =
|
||||
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 3));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 3;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"hr_image", "lr_image"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["hr_image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 12);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KGetters.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create a DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds1 =
|
||||
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 2));
|
||||
std::shared_ptr<Dataset> ds2 =
|
||||
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 3));
|
||||
std::vector<std::string> column_names = {"hr_image", "lr_image"};
|
||||
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_EQ(ds1->GetDatasetSize(), 2);
|
||||
EXPECT_EQ(ds1->GetColumnNames(), column_names);
|
||||
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_EQ(ds2->GetDatasetSize(), 3);
|
||||
EXPECT_EQ(ds2->GetColumnNames(), column_names);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KDecode) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KDecode.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create a DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale, true, std::make_shared<RandomSampler>());
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto hr_image = row["hr_image"];
|
||||
auto lr_image = row["lr_image"];
|
||||
auto h_size = hr_image.Shape().size();
|
||||
auto l_size = lr_image.Shape().size();
|
||||
MS_LOG(INFO) << "Tensor hr_image shape size: " << h_size;
|
||||
MS_LOG(INFO) << "Tensor lr_image shape size: " << l_size;
|
||||
EXPECT_GT(h_size, 1); // Verify decode=true took effect
|
||||
EXPECT_GT(l_size, 1); // Verify decode=true took effect
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KNumSamplers) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KNumSamplers.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create a DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds =
|
||||
DIV2K(dataset_path, usage, downgrade, scale, true, std::make_shared<SequentialSampler>(0, 1));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto hr_image = row["hr_image"];
|
||||
auto lr_image = row["lr_image"];
|
||||
|
||||
MS_LOG(INFO) << "Tensor hr_image shape: " << hr_image.Shape();
|
||||
MS_LOG(INFO) << "Tensor lr_image shape: " << lr_image.Shape();
|
||||
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 1);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
std::string downgrade = "unknown"; // bicubic, unknown, mild, difficult, wild
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
|
||||
// Create a DIV2K Dataset with non-existing dataset dir
|
||||
std::shared_ptr<Dataset> ds0 = DIV2K("NotExistFile", usage, downgrade, scale);
|
||||
EXPECT_NE(ds0, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter0, nullptr);
|
||||
|
||||
// Create a DIV2K Dataset with err usage
|
||||
std::shared_ptr<Dataset> ds1 = DIV2K(dataset_path, "test", downgrade, scale);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
|
||||
// Create a DIV2K Dataset with err scale
|
||||
std::shared_ptr<Dataset> ds2 = DIV2K(dataset_path, usage, downgrade, 16);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
|
||||
// Create a DIV2K Dataset with err downgrade
|
||||
std::shared_ptr<Dataset> ds3 = DIV2K(dataset_path, usage, "downgrade", scale);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
|
||||
// Create a DIV2K Dataset with scale 8 and downgrade unknown
|
||||
std::shared_ptr<Dataset> ds4 = DIV2K(dataset_path, usage, "unknown", 8);
|
||||
EXPECT_NE(ds4, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter4, nullptr);
|
||||
|
||||
// Create a DIV2K Dataset with scale 2 and downgrade mild
|
||||
std::shared_ptr<Dataset> ds5 = DIV2K(dataset_path, usage, "mild", 2);
|
||||
EXPECT_NE(ds5, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter5 = ds5->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input
|
||||
EXPECT_EQ(iter5, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDIV2KWithNullSamplerError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KWithNullSamplerError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
|
||||
std::string usage = "train"; // train valid, all
|
||||
int32_t scale = 2; // 2, 3, 4, 8
|
||||
std::string downgrade = "unknown"; // bicubic, unknown, mild, difficult, wild
|
||||
|
||||
// Create a DIV2K Dataset
|
||||
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale, false, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid DIV2K input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 26 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 31 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 26 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 31 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 17 KiB |
|
@ -0,0 +1,235 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
|
||||
|
||||
DATASET_DIR = "../data/dataset/testDIV2KData/div2k"
|
||||
|
||||
|
||||
def test_div2k_basic(plot=False):
|
||||
usage = "train" # train, valid, all
|
||||
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
|
||||
scale = 2 # 2, 3, 4, 8
|
||||
|
||||
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
|
||||
count = 0
|
||||
hr_images_list = []
|
||||
lr_images_list = []
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
hr_images_list.append(item['hr_image'])
|
||||
lr_images_list.append(item['lr_image'])
|
||||
count = count + 1
|
||||
assert count == 5
|
||||
if plot:
|
||||
flag = "{}_{}_{}".format(usage, scale, downgrade)
|
||||
visualize_dataset(hr_images_list, lr_images_list, flag)
|
||||
|
||||
|
||||
def visualize_dataset(hr_images_list, lr_images_list, flag):
|
||||
"""
|
||||
Helper function to visualize the dataset samples
|
||||
"""
|
||||
image_num = len(hr_images_list)
|
||||
for i in range(image_num):
|
||||
plt.subplot(121)
|
||||
plt.imshow(hr_images_list[i])
|
||||
plt.title('Original')
|
||||
plt.subplot(122)
|
||||
plt.imshow(lr_images_list[i])
|
||||
plt.title(flag)
|
||||
plt.savefig('./div2k_{}_{}.jpg'.format(flag, str(i)))
|
||||
|
||||
|
||||
def test_div2k_basic_func():
|
||||
# case 0: test usage equal to `all`
|
||||
usage = "all" # train, valid, all
|
||||
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
|
||||
scale = 2 # 2, 3, 4, 8
|
||||
|
||||
data0 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
|
||||
num_iter0 = 0
|
||||
for _ in data0.create_dict_iterator(num_epochs=1):
|
||||
num_iter0 += 1
|
||||
assert num_iter0 == 6
|
||||
|
||||
# case 1: test num_samples
|
||||
usage = "train" # train, valid, all
|
||||
|
||||
data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=4)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 4
|
||||
|
||||
# case 2: test repeat
|
||||
data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=3)
|
||||
data2 = data2.repeat(5)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 15
|
||||
|
||||
# case 3: test batch with drop_remainder=False
|
||||
data3 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
|
||||
assert data3.get_dataset_size() == 5
|
||||
assert data3.get_batch_size() == 1
|
||||
resize_op = c_vision.Resize([100, 100])
|
||||
data3 = data3.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
|
||||
data3 = data3.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
|
||||
data3 = data3.batch(batch_size=3) # drop_remainder is default to be False
|
||||
assert data3.get_dataset_size() == 2
|
||||
assert data3.get_batch_size() == 3
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 2
|
||||
|
||||
# case 4: test batch with drop_remainder=True
|
||||
data4 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
|
||||
assert data4.get_dataset_size() == 5
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
|
||||
data4 = data4.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
|
||||
data4 = data4.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert data4.get_dataset_size() == 1
|
||||
assert data4.get_batch_size() == 3
|
||||
num_iter4 = 0
|
||||
for _ in data4.create_dict_iterator(num_epochs=1):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 1
|
||||
|
||||
# case 5: test get_col_names
|
||||
data5 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=1)
|
||||
assert data5.get_col_names() == ["hr_image", "lr_image"]
|
||||
|
||||
|
||||
def test_div2k_sequential_sampler():
|
||||
"""
|
||||
Test DIV2KDataset with SequentialSampler
|
||||
"""
|
||||
usage = "train" # train, valid, all
|
||||
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
|
||||
scale = 2 # 2, 3, 4, 8
|
||||
|
||||
num_samples = 2
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, sampler=sampler)
|
||||
data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
|
||||
num_samples=num_samples)
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
np.testing.assert_array_equal(item1["hr_image"], item2["hr_image"])
|
||||
np.testing.assert_array_equal(item1["lr_image"], item2["lr_image"])
|
||||
num_iter += 1
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_div2k_exception():
|
||||
usage = "train" # train, valid, all
|
||||
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
|
||||
scale = 2 # 2, 3, 4, 8
|
||||
|
||||
error_msg_1 = "does not exist or is not a directory or permission denied!"
|
||||
with pytest.raises(ValueError, match=error_msg_1):
|
||||
ds.DIV2KDataset("NoExistsDir", usage=usage, downgrade=downgrade, scale=scale)
|
||||
|
||||
error_msg_2 = r"Input usage is not within the valid set of \['train', 'valid', 'all'\]."
|
||||
with pytest.raises(ValueError, match=error_msg_2):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage="test", downgrade=downgrade, scale=scale)
|
||||
|
||||
error_msg_3 = r"Input scale is not within the valid set of \[2, 3, 4, 8\]."
|
||||
with pytest.raises(ValueError, match=error_msg_3):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=16, downgrade=downgrade)
|
||||
|
||||
error_msg_4 = r"Input downgrade is not within the valid set of .*"
|
||||
with pytest.raises(ValueError, match=error_msg_4):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=scale, downgrade="downgrade")
|
||||
|
||||
error_msg_5 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_5):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
|
||||
sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_6 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_6):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=0,
|
||||
sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_7 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_7):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=10)
|
||||
|
||||
error_msg_8 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shard_id=0)
|
||||
|
||||
error_msg_9 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_9):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_9):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_9):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_10 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_10):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
|
||||
num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_10):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
|
||||
num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_10):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
|
||||
num_parallel_workers=-2)
|
||||
|
||||
error_msg_11 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_11):
|
||||
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id="0")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
try:
|
||||
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
|
||||
data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
|
||||
data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_div2k_basic()
|
||||
test_div2k_basic_func()
|
||||
test_div2k_sequential_sampler()
|
||||
test_div2k_exception()
|