[feat] [assistant] [I3CKEK] add new dataset operator DIV2K

wangkc123 2021-09-10 09:37:36 +08:00
parent 4ec1824913
commit 3daf3a4d5f
33 changed files with 1599 additions and 1 deletion


@@ -95,6 +95,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
@@ -982,6 +983,33 @@ CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {


@@ -32,6 +32,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
@@ -137,6 +138,18 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
"to create a DIV2KNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale,
bool decode, py::handle sampler) {
auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(div2k->ValidateParams());
return div2k;
}));
}));
PYBIND_REGISTER(
FlickrNode, 2, ([](const py::module *m) {
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")


@@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
mappable_leaf_op.cc
nonmappable_leaf_op.cc
cityscapes_op.cc
div2k_op.cc
flickr_op.cc
)


@@ -0,0 +1,285 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/div2k_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
const std::map<std::string, std::string> DatasetParamMap = {{"train_hr", "DIV2K_train_HR"},
{"valid_hr", "DIV2K_valid_HR"},
{"train_bicubic_x2", "DIV2K_train_LR_bicubic"},
{"train_unknown_x2", "DIV2K_train_LR_unknown"},
{"valid_bicubic_x2", "DIV2K_valid_LR_bicubic"},
{"valid_unknown_x2", "DIV2K_valid_LR_unknown"},
{"train_bicubic_x3", "DIV2K_train_LR_bicubic"},
{"train_unknown_x3", "DIV2K_train_LR_unknown"},
{"valid_bicubic_x3", "DIV2K_valid_LR_bicubic"},
{"valid_unknown_x3", "DIV2K_valid_LR_unknown"},
{"train_bicubic_x4", "DIV2K_train_LR_bicubic"},
{"train_unknown_x4", "DIV2K_train_LR_unknown"},
{"valid_bicubic_x4", "DIV2K_valid_LR_bicubic"},
{"valid_unknown_x4", "DIV2K_valid_LR_unknown"},
{"train_bicubic_x8", "DIV2K_train_LR_x8"},
{"valid_bicubic_x8", "DIV2K_valid_LR_x8"},
{"train_mild_x4", "DIV2K_train_LR_mild"},
{"valid_mild_x4", "DIV2K_valid_LR_mild"},
{"train_difficult_x4", "DIV2K_train_LR_difficult"},
{"valid_difficult_x4", "DIV2K_valid_LR_difficult"},
{"train_wild_x4", "DIV2K_train_LR_wild"},
{"valid_wild_x4", "DIV2K_valid_LR_wild"}};
DIV2KOp::DIV2KOp(int32_t num_workers, const std::string &dataset_dir, const std::string &usage,
const std::string &downgrade, int32_t scale, bool decode, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
dataset_dir_(dataset_dir),
usage_(usage),
downgrade_(downgrade),
scale_(scale),
decode_(decode),
data_schema_(std::move(data_schema)) {
io_block_queues_.Init(num_workers_, queue_size);
}
Status DIV2KOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&DIV2KOp::WorkerEntry, this, std::placeholders::_1), "", id()));
TaskManager::FindMe()->Post();
// The order of the following 3 functions must not be changed!
RETURN_IF_NOT_OK(ParseDIV2KData()); // Parse div2k data and get num rows, blocking
RETURN_IF_NOT_OK(CountDatasetInfo()); // Count the total rows
RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler
return Status::OK();
}
// Load one TensorRow (hr_image, lr_image) from one (hr, lr) image-path pair. Each call produces one TensorRow.
Status DIV2KOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
RETURN_UNEXPECTED_IF_NULL(trow);
std::pair<std::string, std::string> data = image_hr_lr_pairs_[static_cast<size_t>(row_id)];
std::shared_ptr<Tensor> hr_image;
std::shared_ptr<Tensor> lr_image;
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.first, &hr_image));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.second, &lr_image));
  if (decode_) {
Status hr_rc = Decode(hr_image, &hr_image);
if (hr_rc.IsError()) {
std::string err = "Invalid data, failed to decode image: " + data.first;
RETURN_STATUS_UNEXPECTED(err);
}
Status lr_rc = Decode(lr_image, &lr_image);
if (lr_rc.IsError()) {
std::string err = "Invalid data, failed to decode image: " + data.second;
RETURN_STATUS_UNEXPECTED(err);
}
}
(*trow) = TensorRow(row_id, {std::move(hr_image), std::move(lr_image)});
trow->setPath({data.first, data.second});
return Status::OK();
}
void DIV2KOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows:" << num_rows_ << "\nDIV2K DatasetDir: " << dataset_dir_ << "\nUsage: " << usage_
<< "\nScale: " << scale_ << "\nDowngrade: " << downgrade_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
}
}
Status DIV2KOp::ParseDIV2KData() {
std::string hr_dir_key;
std::string lr_dir_key;
if (usage_ == "all") {
std::vector<std::string> usage_all = {"train", "valid"};
for (auto &item : usage_all) {
hr_dir_key = item + "_hr";
lr_dir_key = item + "_" + downgrade_ + "_x" + std::to_string(scale_);
RETURN_IF_NOT_OK(GetDIV2KLRDirRealName(hr_dir_key, lr_dir_key));
RETURN_IF_NOT_OK(GetDIV2KDataByUsage());
}
} else {
hr_dir_key = usage_ + "_hr";
lr_dir_key = usage_ + "_" + downgrade_ + "_x" + std::to_string(scale_);
RETURN_IF_NOT_OK(GetDIV2KLRDirRealName(hr_dir_key, lr_dir_key));
RETURN_IF_NOT_OK(GetDIV2KDataByUsage());
}
return Status::OK();
}
Status DIV2KOp::GetDIV2KLRDirRealName(const std::string &hr_dir_key, const std::string &lr_dir_key) {
std::set<std::string> downgrade_2017 = {"bicubic", "unknown"};
std::set<int32_t> scale_2017 = {2, 3, 4};
  hr_dir_real_name_ = DatasetParamMap.find(hr_dir_key)->second;
  auto lr_it = DatasetParamMap.find(lr_dir_key);
  if (lr_it == DatasetParamMap.end()) {
    std::string out_str = "{\n";
    std::for_each(DatasetParamMap.begin(), DatasetParamMap.end(),
                  [&out_str](const auto &item) -> void {
                    out_str += ("\t" + item.first + ": " + item.second + ",\n");
                  });
    out_str += "\n}";
    RETURN_STATUS_UNEXPECTED("Invalid param, " + lr_dir_key + " not found in DatasetParamMap: \n" + out_str);
}
if (downgrade_2017.find(downgrade_) != downgrade_2017.end() && scale_2017.find(scale_) != scale_2017.end()) {
Path ntire_2017(lr_it->second);
lr_dir_real_name_ = (ntire_2017 / ("X" + std::to_string(scale_))).ToString();
} else {
lr_dir_real_name_ = lr_it->second;
}
return Status::OK();
}
Status DIV2KOp::GetDIV2KDataByUsage() {
const std::string kExtension = ".png";
auto real_dataset_dir = Common::GetRealPath(dataset_dir_);
if (!real_dataset_dir.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_);
}
Path dataset_dir(real_dataset_dir.value());
Path hr_images_dir = dataset_dir / hr_dir_real_name_;
Path lr_images_dir = dataset_dir / lr_dir_real_name_;
if (!hr_images_dir.IsDirectory()) {
RETURN_STATUS_UNEXPECTED("Invalid path, " + hr_images_dir.ToString() + " is an invalid directory path.");
}
if (!lr_images_dir.IsDirectory()) {
RETURN_STATUS_UNEXPECTED("Invalid path, " + lr_images_dir.ToString() + " is an invalid directory path.");
}
auto hr_it = Path::DirIterator::OpenDirectory(&hr_images_dir);
if (hr_it == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid path, failed to open directory: " + hr_images_dir.ToString());
}
std::string image_name;
std::string image_id_scale;
std::string lr_image_file_path_;
std::map<std::string, std::string> image_hr_lr_map_;
std::map<std::string, std::string> downgrade_2018 = {{"mild", "m"}, {"difficult", "d"}, {"wild", "w"}};
while (hr_it->HasNext()) {
try {
Path hr_img_file = hr_it->Next();
if (hr_img_file.Extension() != kExtension) {
continue;
}
image_name = hr_img_file.Basename();
image_id_scale = image_name.substr(0, image_name.find_last_of(".")) + "x" + std::to_string(scale_);
Path hr_image_file_path = hr_images_dir / image_name;
auto lr_it = downgrade_2018.find(downgrade_);
if (lr_it != downgrade_2018.end()) {
lr_image_file_path_ = (lr_images_dir / (image_id_scale + lr_it->second + kExtension)).ToString();
} else {
lr_image_file_path_ = (lr_images_dir / (image_id_scale + kExtension)).ToString();
}
Path lr_image_file_path(lr_image_file_path_);
if (!lr_image_file_path.Exists()) {
RETURN_STATUS_UNEXPECTED("Invalid file, " + lr_image_file_path.ToString() + " not found.");
}
image_hr_lr_map_[hr_image_file_path.ToString()] = lr_image_file_path.ToString();
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid path, failed to load DIV2K Dataset: " + dataset_dir_);
}
}
  for (const auto &item : image_hr_lr_map_) {
    image_hr_lr_pairs_.emplace_back(item.first, item.second);
  }
return Status::OK();
}
Status DIV2KOp::CountDatasetInfo() {
num_rows_ = static_cast<int64_t>(image_hr_lr_pairs_.size());
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API DIV2KDataset. Please check file path or dataset API.");
}
return Status::OK();
}
Status DIV2KOp::CountTotalRows(const std::string &dir, const std::string &usage, const std::string &downgrade,
int32_t scale, int64_t *count) {
// the logic of counting the number of samples is copied from ParseDIV2KData()
RETURN_UNEXPECTED_IF_NULL(count);
*count = 0;
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto new_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
// build a new unique schema object
auto new_schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(new_schema->AddColumn(ColDescriptor("hr_image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
new_schema->AddColumn(ColDescriptor("lr_image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connect_size = cfg->op_connector_size();
std::shared_ptr<DIV2KOp> op = std::make_shared<DIV2KOp>(
num_workers, dir, usage, downgrade, scale, false, op_connect_size, std::move(new_schema), std::move(new_sampler));
RETURN_IF_NOT_OK(op->ParseDIV2KData());
*count = static_cast<int64_t>(op->image_hr_lr_pairs_.size());
return Status::OK();
}
Status DIV2KOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
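A note on the lookup above: GetDIV2KLRDirRealName builds a key of the form "<usage>_<downgrade>_x<scale>" into the parameter map, and GetDIV2KDataByUsage derives each low-resolution file name from the high-resolution stem. A standalone sketch of that resolution (illustrative only, not part of this commit):

// Standalone sketch (not committed code) of the name resolution performed by
// GetDIV2KLRDirRealName and GetDIV2KDataByUsage above.
#include <iostream>
#include <string>

int main() {
  std::string usage = "train";     // "train", "valid" or "all"
  std::string downgrade = "mild";  // an NTIRE 2018 track
  int scale = 4;

  // Lookup key into the parameter map, e.g. "train_mild_x4" -> "DIV2K_train_LR_mild".
  std::string lr_dir_key = usage + "_" + downgrade + "_x" + std::to_string(scale);

  // NTIRE 2017 tracks (bicubic/unknown at x2/x3/x4) add an "X<scale>" subdirectory,
  // while NTIRE 2018 tracks (mild/difficult/wild) append a one-letter suffix to the
  // file stem, so HR "0001.png" pairs with LR "0001x4m.png".
  std::cout << lr_dir_key << " -> DIV2K_train_LR_mild/0001x4m.png" << std::endl;
  return 0;
}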


@@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
class DIV2KOp : public MappableLeafOp {
public:
  /// \brief Constructor.
  /// \param[in] int32_t num_workers - number of workers reading images in parallel.
  /// \param[in] std::string dataset_dir - directory of the DIV2K dataset.
  /// \param[in] std::string usage - the type of dataset. Acceptable usages include "train", "valid" or "all".
  /// \param[in] std::string downgrade - the mode of downgrade. Acceptable downgrades include "bicubic", "unknown",
  ///     "mild", "difficult" or "wild".
  /// \param[in] int32_t scale - the scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
  /// \param[in] bool decode - decode the images after reading.
  /// \param[in] int32_t queue_size - connector queue size.
  /// \param[in] DataSchema data_schema - the schema of each column in output data.
  /// \param[in] std::shared_ptr<SamplerRT> sampler - sampler tells DIV2KOp what to read.
DIV2KOp(int32_t num_workers, const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
int32_t scale, bool decode, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~DIV2KOp() = default;
  /// \brief A print method typically used for debugging.
  /// \param[out] out - The output stream to write to.
  /// \param[in] show_all - Whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Function to count the number of samples in the DIV2K dataset.
/// \param[in] dir - path to the DIV2K directory.
/// \param[in] usage - the type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade - the mode of downgrade. Acceptable downgrades include "bicubic", "unknown",
/// "mild", "difficult" or "wild".
/// \param[in] scale - the scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[out] count - output arg that will hold the actual dataset size.
/// \return Status - The status code returned.
static Status CountTotalRows(const std::string &dir, const std::string &usage, const std::string &downgrade,
int32_t scale, int64_t *count);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "DIV2KOp"; }
private:
  /// \brief Load a tensor row according to an (hr, lr) image pair.
  /// \param[in] row_id_type index - index of the row to load.
  /// \param[out] TensorRow trow - the hr and lr images are read into this tensor row.
/// \return Status - The status code returned.
Status LoadTensorRow(row_id_type index, TensorRow *trow) override;
  /// \brief Initialize the op and launch worker threads; called once before rows are fetched.
/// \return Status - The status code returned.
Status LaunchThreadsAndInitOp() override;
  /// \brief Get the real directory names of the high and low resolution images in the DIV2K dataset.
  /// \param[in] hr_dir_key - the key of the high resolution images dir.
  /// \param[in] lr_dir_key - the key of the low resolution images dir.
/// \return Status - The status code returned.
Status GetDIV2KLRDirRealName(const std::string &hr_dir_key, const std::string &lr_dir_key);
/// \brief Parse DIV2K data.
/// \return Status - The status code returned.
Status ParseDIV2KData();
/// \brief Get DIV2K data by usage.
/// \return Status - The status code returned.
Status GetDIV2KDataByUsage();
  /// \brief Count the number of rows and samples.
/// \return Status - The status code returned.
Status CountDatasetInfo();
/// \brief Private function for computing the assignment of the column name map.
/// \return Status - The status code returned.
Status ComputeColMap() override;
std::string dataset_dir_;
std::string usage_;
int32_t scale_;
std::string downgrade_;
bool decode_;
std::unique_ptr<DataSchema> data_schema_;
std::vector<std::pair<std::string, std::string>> image_hr_lr_pairs_;
std::string hr_dir_real_name_;
std::string lr_dir_real_name_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_DIV2K_OP_H_


@@ -82,6 +82,7 @@ constexpr char kCityscapesNode[] = "CityscapesDataset";
constexpr char kCLUENode[] = "CLUEDataset";
constexpr char kCocoNode[] = "CocoDataset";
constexpr char kCSVNode[] = "CSVDataset";
constexpr char kDIV2KNode[] = "DIV2KDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";


@@ -11,6 +11,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
clue_node.cc
coco_node.cc
csv_node.cc
div2k_node.cc
flickr_node.cc
image_folder_node.cc
manifest_node.cc


@@ -0,0 +1,158 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/div2k_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for DIV2KNode
DIV2KNode::DIV2KNode(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
int32_t scale, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
downgrade_(downgrade),
scale_(scale),
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> DIV2KNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<DIV2KNode>(dataset_dir_, usage_, downgrade_, scale_, decode_, sampler, cache_);
return node;
}
void DIV2KNode::Print(std::ostream &out) const {
out << Name() + "(dataset dir:" + dataset_dir_;
out << ", usage:" + usage_ << ", scale:" + std::to_string(scale_) << ", downgrade:" + downgrade_;
if (sampler_ != nullptr) {
out << ", sampler";
}
if (cache_ != nullptr) {
out << ", cache";
}
out << ")";
}
Status DIV2KNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("DIV2KNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("DIV2KNode", usage_, {"train", "valid", "all"}));
RETURN_IF_NOT_OK(ValidateStringValue("DIV2KNode", downgrade_, {"bicubic", "unknown", "mild", "difficult", "wild"}));
RETURN_IF_NOT_OK(ValidateDatasetSampler("DIV2KNode", sampler_));
std::set<int32_t> scale_arr = {2, 3, 4, 8};
auto s_it = scale_arr.find(scale_);
if (s_it == scale_arr.end()) {
std::string err_msg = "DIV2KNode: " + std::to_string(scale_) + " does not match any mode in [2, 3, 4, 8].";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (scale_ == 8 && downgrade_ != "bicubic") {
std::string err_msg = "DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::set<std::string> downgrade_2018 = {"mild", "difficult", "wild"};
auto it = downgrade_2018.find(downgrade_);
if (it != downgrade_2018.end() && scale_ != 4) {
std::string err_msg = "DIV2KNode: " + downgrade_ + " downgrade requires scale equal to 4.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Function to build DIV2KOp for DIV2K
Status DIV2KNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("hr_image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("lr_image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto div2k_op = std::make_shared<DIV2KOp>(num_workers_, dataset_dir_, usage_, downgrade_, scale_, decode_,
connector_que_size_, std::move(schema), std::move(sampler_rt));
div2k_op->set_total_repeats(GetTotalRepeats());
div2k_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(div2k_op);
return Status::OK();
}
// Get the shard id of node
Status DIV2KNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size
Status DIV2KNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(DIV2KOp::CountTotalRows(dataset_dir_, usage_, downgrade_, scale_, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status DIV2KNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["downgrade"] = downgrade_;
args["scale"] = scale_;
args["decode"] = decode_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
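For reference, the (downgrade, scale) coupling that ValidateParams enforces can be condensed into a single predicate; a minimal sketch with an assumed helper name, not part of this commit:

// Sketch (not committed code) of the rules checked by DIV2KNode::ValidateParams.
#include <cstdint>
#include <set>
#include <string>

bool IsValidDIV2KCombo(const std::string &downgrade, int32_t scale) {
  static const std::set<std::string> kDowngrades = {"bicubic", "unknown", "mild", "difficult", "wild"};
  static const std::set<int32_t> kScales = {2, 3, 4, 8};
  if (kDowngrades.count(downgrade) == 0 || kScales.count(scale) == 0) {
    return false;
  }
  if (scale == 8 && downgrade != "bicubic") {
    return false;  // scale 8 exists only for the bicubic downgrade
  }
  static const std::set<std::string> kNtire2018 = {"mild", "difficult", "wild"};
  if (kNtire2018.count(downgrade) != 0 && scale != 4) {
    return false;  // the 2018 tracks ship only x4 images
  }
  return true;
}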


@@ -0,0 +1,109 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class DIV2KNode : public MappableSourceNode {
public:
/// \brief Constructor.
DIV2KNode(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade, int32_t scale,
bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~DIV2KNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kDIV2KNode; }
/// \brief Print the description.
/// \param[out] out - The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
  /// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[out] node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Getter functions.
const std::string &Usage() const { return usage_; }
/// \brief Getter functions.
const int32_t &Scale() const { return scale_; }
/// \brief Getter functions.
const std::string &Downgrade() const { return downgrade_; }
/// \brief Getter functions
bool Decode() const { return decode_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
std::string usage_;
int32_t scale_;
std::string downgrade_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_DIV2K_NODE_H_


@@ -1524,6 +1524,112 @@ inline std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_f
cache);
}
/// \class DIV2KDataset
/// \brief A source dataset for reading and parsing DIV2K dataset.
class DIV2KDataset : public Dataset {
public:
/// \brief Constructor of DIV2KDataset.
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
/// "difficult" or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
/// \param[in] cache Tensor cache to use.
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of DIV2KDataset.
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
/// "difficult" or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of DIV2KDataset.
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild",
/// "difficult" or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of DIV2KDataset.
~DIV2KDataset() = default;
};
/// \brief Function to create a DIV2KDataset.
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
/// or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
const std::string &downgrade, int32_t scale, bool decode = false,
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
decode, sampler, cache);
}
/// \brief Function to create a DIV2KDataset.
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
/// or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
const std::string &downgrade, int32_t scale, bool decode,
const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
decode, sampler, cache);
}
/// \brief Function to create a DIV2KDataset.
/// \note The generated dataset has two columns ["hr_image", "lr_image"].
/// \param[in] dataset_dir The dataset dir to be read.
/// \param[in] usage The type of dataset. Acceptable usages include "train", "valid" or "all".
/// \param[in] downgrade The mode of downgrade. Acceptable downgrades include "bicubic", "unknown", "mild", "difficult"
/// or "wild".
/// \param[in] scale The scale of downgrade. Acceptable scales include 2, 3, 4 or 8.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
const std::string &downgrade, int32_t scale, bool decode,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<DIV2KDataset>(StringToChar(dataset_dir), StringToChar(usage), StringToChar(downgrade), scale,
decode, sampler, cache);
}
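// A minimal usage sketch of the DIV2K factory above (placeholder path, error
// handling elided; illustrative only, not part of this commit). Iterates the
// two output columns, "hr_image" and "lr_image".
#include <memory>
#include <unordered_map>
#include "minddata/dataset/include/dataset/datasets.h"

void ReadDIV2K() {
  using namespace mindspore::dataset;
  std::shared_ptr<Dataset> ds = DIV2K("/path/to/DIV2K", "train", "bicubic", 2, /*decode=*/true);
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  std::unordered_map<std::string, mindspore::MSTensor> row;
  iter->GetNextRow(&row);
  while (!row.empty()) {
    auto hr_image = row["hr_image"];
    auto lr_image = row["lr_image"];
    iter->GetNextRow(&row);
  }
  iter->Stop();
}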
/// \class FlickrDataset
/// \brief A source dataset for reading and parsing Flickr dataset.
class FlickrDataset : public Dataset {


@@ -38,6 +38,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class CLUEDataset;
friend class CocoDataset;
friend class CSVDataset;
friend class DIV2KDataset;
friend class FlickrDataset;
friend class ImageFolderDataset;
friend class ManifestDataset;


@@ -64,7 +64,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -6643,3 +6643,190 @@ class CityscapesDataset(MappableDataset):
def parse(self, children=None):
return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler)
class DIV2KDataset(MappableDataset):
"""
    A source dataset for reading and parsing the DIV2K dataset.
The generated dataset has two columns :py:obj:`[hr_image, lr_image]`.
The tensor of column :py:obj:`hr_image` is of the uint8 type.
The tensor of column :py:obj:`lr_image` is of the uint8 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str): Acceptable usages include `train`, `valid` or `all` (default=`train`).
downgrade (str): Acceptable downgrades include `bicubic`, `unknown`, `mild`, `difficult` or
`wild` (default=`bicubic`).
scale (int): Acceptable scales include 2, 3, 4 or 8 (default=2).
When `downgrade` is `bicubic`, scale can be 2, 3, 4, 8.
When `downgrade` is `unknown`, scale can only be 2, 3, 4.
When `downgrade` is `mild`, `difficult` or `wild`, scale can only be 4.
num_samples (int, optional): The number of images to be included in the dataset.
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
order behavior shown in the table).
decode (bool, optional): Decode the images after reading (default=False).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, `num_samples` reflects
        the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir is invalid or does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the maximum number of threads.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If dataset_dir does not exist.
ValueError: If usage is invalid.
ValueError: If downgrade is invalid.
ValueError: If scale is invalid.
        ValueError: If scale is 8 and downgrade is not `bicubic`.
        ValueError: If downgrade is one of [`mild`, `difficult`, `wild`] and scale is not 4.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> div2k_dataset_dir = "/path/to/div2k_dataset_directory"
>>>
>>> # 1) Get all samples from DIV2K dataset in sequence
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
        ...                           shuffle=False)
>>>
>>> # 2) Randomly select 350 samples from DIV2K dataset
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
        ...                           num_samples=350, shuffle=True)
>>>
>>> # 3) Get samples from DIV2K dataset for shard 0 in a 2-way distributed training
>>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
        ...                           num_shards=2, shard_id=0)
>>>
>>> # In DIV2K dataset, each dictionary has keys "hr_image" and "lr_image"
About DIV2K dataset:
The DIV2K dataset consists of 1000 2K resolution images, among which 800 images are for training, 100 images
        are for validation and 100 images are for testing. NTIRE 2017 and NTIRE 2018 include only the training
        and validation datasets.
        You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
        Take the training set as an example.
.. code-block::
.
DIV2K
DIV2K_train_HR
| 0001.png
| 0002.png
| ...
DIV2K_train_LR_bicubic
| X2
| | 0001x2.png
| | 0002x2.png
| | ...
| X3
| | 0001x3.png
| | 0002x3.png
| | ...
| X4
| 0001x4.png
| 0002x4.png
| ...
DIV2K_train_LR_unknown
| X2
| | 0001x2.png
| | 0002x2.png
| | ...
| X3
| | 0001x3.png
| | 0002x3.png
| | ...
| X4
| 0001x4.png
| 0002x4.png
| ...
DIV2K_train_LR_mild
| 0001x4m.png
| 0002x4m.png
| ...
DIV2K_train_LR_difficult
| 0001x4d.png
| 0002x4d.png
| ...
DIV2K_train_LR_wild
| 0001x4w.png
| 0002x4w.png
| ...
DIV2K_train_LR_x8
0001x8.png
0002x8.png
...
Citation:
.. code-block::
@InProceedings{Agustsson_2017_CVPR_Workshops,
author = {Agustsson, Eirikur and Timofte, Radu},
title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf",
month = {July},
year = {2017}
}
"""
@check_div2k_dataset
def __init__(self, dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None,
num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None,
shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = usage
self.scale = scale
self.downgrade = downgrade
self.decode = replace_none(decode, False)
def parse(self, children=None):
return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler)


@@ -1489,3 +1489,45 @@ def check_cityscapes_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_div2k_dataset(method):
"""A wrapper that wraps a parameter checker around the original DIV2KDataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
check_valid_str(usage, ['train', 'valid', 'all'], "usage")
downgrade = param_dict.get('downgrade')
check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')
validate_dataset_param_value(['scale'], param_dict, int)
scale = param_dict.get('scale')
scale_values = [2, 3, 4, 8]
if scale not in scale_values:
raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))
if scale == 8 and downgrade != "bicubic":
raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")
downgrade_2018 = ["mild", "difficult", "wild"]
if downgrade in downgrade_2018 and scale != 4:
raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method


@@ -22,6 +22,7 @@ SET(DE_UT_SRCS
c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_div2k_test.cc
c_api_dataset_flickr_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_manifest_test.cc


@@ -0,0 +1,305 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::Tensor;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestDIV2KBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KBasic.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
// Create a DIV2K Dataset
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto hr_image = row["hr_image"];
auto lr_image = row["lr_image"];
MS_LOG(INFO) << "Tensor hr_image shape: " << hr_image.Shape();
MS_LOG(INFO) << "Tensor lr_image shape: " << lr_image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDIV2KBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KBasicWithPipeline.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
  // Create two DIV2K Datasets
std::shared_ptr<Dataset> ds1 =
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 2));
std::shared_ptr<Dataset> ds2 =
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 3));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
  // Create two Repeat operations on ds
int32_t repeat_num = 3;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 2;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
  // Create two Project operations on ds
std::vector<std::string> column_project = {"hr_image", "lr_image"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
  // Create a Concat operation on the two datasets
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["hr_image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 12);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDIV2KGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KGetters.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
  // Create two DIV2K Datasets
std::shared_ptr<Dataset> ds1 =
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 2));
std::shared_ptr<Dataset> ds2 =
DIV2K(dataset_path, usage, downgrade, scale, false, std::make_shared<RandomSampler>(false, 3));
std::vector<std::string> column_names = {"hr_image", "lr_image"};
EXPECT_NE(ds1, nullptr);
EXPECT_EQ(ds1->GetDatasetSize(), 2);
EXPECT_EQ(ds1->GetColumnNames(), column_names);
EXPECT_NE(ds2, nullptr);
EXPECT_EQ(ds2->GetDatasetSize(), 3);
EXPECT_EQ(ds2->GetColumnNames(), column_names);
}
TEST_F(MindDataTestPipeline, TestDIV2KDecode) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KDecode.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
// Create a DIV2K Dataset
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale, true, std::make_shared<RandomSampler>());
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto hr_image = row["hr_image"];
auto lr_image = row["lr_image"];
auto h_size = hr_image.Shape().size();
auto l_size = lr_image.Shape().size();
MS_LOG(INFO) << "Tensor hr_image shape size: " << h_size;
MS_LOG(INFO) << "Tensor lr_image shape size: " << l_size;
EXPECT_GT(h_size, 1); // Verify decode=true took effect
EXPECT_GT(l_size, 1); // Verify decode=true took effect
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDIV2KNumSamplers) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KNumSamplers.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "bicubic"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
// Create a DIV2K Dataset
std::shared_ptr<Dataset> ds =
DIV2K(dataset_path, usage, downgrade, scale, true, std::make_shared<SequentialSampler>(0, 1));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto hr_image = row["hr_image"];
auto lr_image = row["lr_image"];
MS_LOG(INFO) << "Tensor hr_image shape: " << hr_image.Shape();
MS_LOG(INFO) << "Tensor lr_image shape: " << lr_image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDIV2KError) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KError.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
std::string downgrade = "unknown"; // bicubic, unknown, mild, difficult, wild
int32_t scale = 2; // 2, 3, 4, 8
// Create a DIV2K Dataset with non-existing dataset dir
std::shared_ptr<Dataset> ds0 = DIV2K("NotExistFile", usage, downgrade, scale);
EXPECT_NE(ds0, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter0, nullptr);
// Create a DIV2K Dataset with err usage
std::shared_ptr<Dataset> ds1 = DIV2K(dataset_path, "test", downgrade, scale);
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter1, nullptr);
// Create a DIV2K Dataset with err scale
std::shared_ptr<Dataset> ds2 = DIV2K(dataset_path, usage, downgrade, 16);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter2, nullptr);
// Create a DIV2K Dataset with err downgrade
std::shared_ptr<Dataset> ds3 = DIV2K(dataset_path, usage, "downgrade", scale);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter3, nullptr);
// Create a DIV2K Dataset with scale 8 and downgrade unknown
std::shared_ptr<Dataset> ds4 = DIV2K(dataset_path, usage, "unknown", 8);
EXPECT_NE(ds4, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter4, nullptr);
// Create a DIV2K Dataset with scale 2 and downgrade mild
std::shared_ptr<Dataset> ds5 = DIV2K(dataset_path, usage, "mild", 2);
EXPECT_NE(ds5, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter5 = ds5->CreateIterator();
// Expect failure: invalid DIV2K input
EXPECT_EQ(iter5, nullptr);
}
TEST_F(MindDataTestPipeline, TestDIV2KWithNullSamplerError) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDIV2KWithNullSamplerError.";
std::string dataset_path = datasets_root_path_ + "/testDIV2KData/div2k";
  std::string usage = "train";  // train, valid, all
int32_t scale = 2; // 2, 3, 4, 8
std::string downgrade = "unknown"; // bicubic, unknown, mild, difficult, wild
// Create a DIV2K Dataset
std::shared_ptr<Dataset> ds = DIV2K(dataset_path, usage, downgrade, scale, false, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid DIV2K input, sampler cannot be nullptr
EXPECT_EQ(iter, nullptr);
}

(17 binary test image files added: DIV2K test PNGs, 17 KiB to 31 KiB each; content not shown.)
View File

@ -0,0 +1,235 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import matplotlib.pyplot as plt
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

DATASET_DIR = "../data/dataset/testDIV2KData/div2k"


def test_div2k_basic(plot=False):
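    """
    Test DIV2KDataset basic read of hr/lr image pairs with optional visualization
    """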
usage = "train" # train, valid, all
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
scale = 2 # 2, 3, 4, 8
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
count = 0
hr_images_list = []
lr_images_list = []
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
hr_images_list.append(item['hr_image'])
lr_images_list.append(item['lr_image'])
count = count + 1
assert count == 5
if plot:
flag = "{}_{}_{}".format(usage, scale, downgrade)
        visualize_dataset(hr_images_list, lr_images_list, flag)


def visualize_dataset(hr_images_list, lr_images_list, flag):
"""
Helper function to visualize the dataset samples
"""
image_num = len(hr_images_list)
for i in range(image_num):
plt.subplot(121)
plt.imshow(hr_images_list[i])
plt.title('Original')
plt.subplot(122)
plt.imshow(lr_images_list[i])
plt.title(flag)
        plt.savefig('./div2k_{}_{}.jpg'.format(flag, str(i)))


def test_div2k_basic_func():
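    """
    Test DIV2KDataset with usage "all", num_samples, repeat, batch and get_col_names
    """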
# case 0: test usage equal to `all`
usage = "all" # train, valid, all
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
scale = 2 # 2, 3, 4, 8
data0 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
num_iter0 = 0
for _ in data0.create_dict_iterator(num_epochs=1):
num_iter0 += 1
assert num_iter0 == 6
# case 1: test num_samples
usage = "train" # train, valid, all
data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=4)
num_iter1 = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter1 += 1
assert num_iter1 == 4
# case 2: test repeat
data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=3)
data2 = data2.repeat(5)
num_iter2 = 0
for _ in data2.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 15
# case 3: test batch with drop_remainder=False
data3 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
assert data3.get_dataset_size() == 5
assert data3.get_batch_size() == 1
resize_op = c_vision.Resize([100, 100])
data3 = data3.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
data3 = data3.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
    data3 = data3.batch(batch_size=3)  # drop_remainder defaults to False
assert data3.get_dataset_size() == 2
assert data3.get_batch_size() == 3
num_iter3 = 0
for _ in data3.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 2
# case 4: test batch with drop_remainder=True
data4 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
assert data4.get_dataset_size() == 5
assert data4.get_batch_size() == 1
data4 = data4.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
data4 = data4.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
    data4 = data4.batch(batch_size=3, drop_remainder=True)  # the final incomplete batch will be dropped
assert data4.get_dataset_size() == 1
assert data4.get_batch_size() == 3
num_iter4 = 0
for _ in data4.create_dict_iterator(num_epochs=1):
num_iter4 += 1
assert num_iter4 == 1
# case 5: test get_col_names
data5 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=1)
    assert data5.get_col_names() == ["hr_image", "lr_image"]


def test_div2k_sequential_sampler():
"""
Test DIV2KDataset with SequentialSampler
"""
usage = "train" # train, valid, all
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
scale = 2 # 2, 3, 4, 8
num_samples = 2
sampler = ds.SequentialSampler(num_samples=num_samples)
data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, sampler=sampler)
data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
num_samples=num_samples)
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["hr_image"], item2["hr_image"])
np.testing.assert_array_equal(item1["lr_image"], item2["lr_image"])
num_iter += 1
    assert num_iter == num_samples


def test_div2k_exception():
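    """
    Test DIV2KDataset with invalid arguments and a failing map operation
    """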
usage = "train" # train, valid, all
downgrade = "bicubic" # bicubic, unknown, mild, difficult, wild
scale = 2 # 2, 3, 4, 8
error_msg_1 = "does not exist or is not a directory or permission denied!"
with pytest.raises(ValueError, match=error_msg_1):
ds.DIV2KDataset("NoExistsDir", usage=usage, downgrade=downgrade, scale=scale)
error_msg_2 = r"Input usage is not within the valid set of \['train', 'valid', 'all'\]."
with pytest.raises(ValueError, match=error_msg_2):
ds.DIV2KDataset(DATASET_DIR, usage="test", downgrade=downgrade, scale=scale)
error_msg_3 = r"Input scale is not within the valid set of \[2, 3, 4, 8\]."
with pytest.raises(ValueError, match=error_msg_3):
ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=16, downgrade=downgrade)
error_msg_4 = r"Input downgrade is not within the valid set of .*"
with pytest.raises(ValueError, match=error_msg_4):
ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=scale, downgrade="downgrade")
error_msg_5 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_5):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
sampler=ds.PKSampler(3))
error_msg_6 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_6):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=0,
sampler=ds.PKSampler(3))
error_msg_7 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_7):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=10)
error_msg_8 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_8):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shard_id=0)
error_msg_9 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_9):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_9):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_9):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=5)
error_msg_10 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_10):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_10):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_10):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
num_parallel_workers=-2)
error_msg_11 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_11):
ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id="0")
    def exception_func(item):
        raise Exception("Error occurred!")
try:
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)
    # Repeat the failure path on the second output column.
    try:
        data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
        data = data.map(operations=exception_func, input_columns=["lr_image"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)

if __name__ == "__main__":
test_div2k_basic()
test_div2k_basic_func()
test_div2k_sequential_sampler()
test_div2k_exception()
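
For quick reference, a minimal usage sketch of the new operator follows, mirroring the Python tests above. The dataset root is a placeholder, not a path from this commit; the column names and argument values come from the tests.

import mindspore.dataset as ds

# Placeholder path: point this at a DIV2K root laid out like the test data above.
div2k_dir = "/path/to/div2k"

# Each row pairs a high-resolution image with its downgraded counterpart.
dataset = ds.DIV2KDataset(div2k_dir, usage="train", downgrade="bicubic", scale=2, decode=True)
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
    hr_image, lr_image = item["hr_image"], item["lr_image"]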