forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6VE] add new data operator PhotoTour
This commit is contained in:
parent
889f3ddc1f
commit
6594b90a58
|
@ -104,6 +104,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
|
@ -1305,6 +1306,33 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
|
|||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
||||
sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
||||
sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
||||
sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const bool small, const bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -46,6 +46,7 @@
|
|||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
|
@ -276,6 +277,17 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>(
|
||||
*m, "PhotoTourNode", "to create a PhotoTourNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) {
|
||||
auto photo_tour =
|
||||
std::make_shared<PhotoTourNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(photo_tour->ValidateParams());
|
||||
return photo_tour;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Places365Node, DatasetNode, std::shared_ptr<Places365Node>>(
|
||||
*m, "Places365Node", "to create a Places365Node")
|
||||
|
|
|
@ -25,6 +25,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
emnist_op.cc
|
||||
fake_image_op.cc
|
||||
places365_op.cc
|
||||
photo_tour_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,399 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/photo_tour_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "mindspore/ccsrc/debug/common.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
constexpr uint32_t kPatchNumPerRow = 16;
|
||||
constexpr uint32_t kPatchNumPerCol = 16;
|
||||
constexpr uint32_t kColPerPatch = 64;
|
||||
constexpr uint32_t kRowPerPatch = 64;
|
||||
|
||||
const std::map<std::string, int> kLens = {
|
||||
{"notredame", 468159}, {"yosemite", 633587}, {"liberty", 450092},
|
||||
{"liberty_harris", 379587}, {"yosemite_harris", 450912}, {"notredame_harris", 325295},
|
||||
};
|
||||
constexpr char kImageExt[] = "bmp";
|
||||
constexpr char kInfoFile[] = "info.txt";
|
||||
constexpr char kMatchesFiles[] = "m50_100000_100000_0.txt";
|
||||
const std::map<std::string, bool> kTrain = {{"train", true}, {"test", false}};
|
||||
|
||||
PhotoTourOp::PhotoTourOp(const std::string &dataset_dir, const std::string &name, const std::string &usage,
|
||||
int32_t num_workers, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
dataset_dir_(dataset_dir),
|
||||
name_(name),
|
||||
usage_(usage),
|
||||
buf_cnt_(0),
|
||||
data_schema_(std::move(data_schema)),
|
||||
image_names_({}),
|
||||
image_bmps_({}),
|
||||
matches_({}),
|
||||
labels_({}) {}
|
||||
|
||||
Status PhotoTourOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
if (train_) {
|
||||
std::shared_ptr<Tensor> image;
|
||||
// make a copy of cached tensor
|
||||
RETURN_IF_NOT_OK(GetPhotoTourDataTensor(row_id, &image));
|
||||
(*trow) = TensorRow(row_id, {std::move(image)});
|
||||
trow->setPath({image_names_[row_id / (kPatchNumPerRow * kPatchNumPerCol)],
|
||||
std::to_string(row_id % (kPatchNumPerRow * kPatchNumPerCol))});
|
||||
|
||||
} else {
|
||||
std::shared_ptr<Tensor> image1, image2, matches;
|
||||
// make a copy of cached tensor
|
||||
uint32_t row1 = std::get<0>(matches_[row_id]);
|
||||
uint32_t row2 = std::get<1>(matches_[row_id]);
|
||||
|
||||
RETURN_IF_NOT_OK(GetPhotoTourDataTensor(row1, &image1));
|
||||
RETURN_IF_NOT_OK(GetPhotoTourDataTensor(row2, &image2));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(std::get<2>(matches_[row_id]), &matches));
|
||||
(*trow) = TensorRow(row_id, {std::move(image1), std::move(image2), std::move(matches)});
|
||||
trow->setPath({image_names_[row1 / (kPatchNumPerRow * kPatchNumPerCol)],
|
||||
std::to_string(row1 % (kPatchNumPerRow * kPatchNumPerCol)),
|
||||
image_names_[row2 / (kPatchNumPerRow * kPatchNumPerCol)],
|
||||
std::to_string(row2 % (kPatchNumPerRow * kPatchNumPerCol))});
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void PhotoTourOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows: " << num_rows_ << "\nPhotoTour directory: " << dataset_dir_ << "\nName: " << name_
|
||||
<< "\nUsage: " << usage_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp.
|
||||
Status PhotoTourOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || labels_.empty()) {
|
||||
if (labels_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("No image found in dataset, please check if image was read successfully.");
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place, "
|
||||
"it must be empty before using GetClassIds.");
|
||||
}
|
||||
}
|
||||
if (train_) {
|
||||
for (size_t i = 0; i < labels_.size(); ++i) {
|
||||
(*cls_ids)[labels_[i]].push_back(i);
|
||||
}
|
||||
for (auto &pair : (*cls_ids)) {
|
||||
pair.second.shrink_to_fit();
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < matches_.size(); ++i) {
|
||||
(*cls_ids)[std::get<2>(matches_[i])].push_back(i);
|
||||
}
|
||||
for (auto &pair : (*cls_ids)) {
|
||||
pair.second.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool PhotoTourOp::EndsWith(const std::string &s, const std::string &sub) {
|
||||
return s.rfind(sub) == (s.length() - sub.length()) ? true : false;
|
||||
}
|
||||
|
||||
Status PhotoTourOp::GetFileContent(const std::string &info_file, std::string *ans) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ans);
|
||||
std::ifstream reader;
|
||||
reader.open(info_file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!reader.fail(), "Invalid file, failed to open PhotoTour info file: " + info_file);
|
||||
(void)reader.seekg(0, std::ios::end);
|
||||
std::size_t size = reader.tellg();
|
||||
(void)reader.seekg(0, std::ios::beg);
|
||||
char *buffer = new char[size + 1];
|
||||
(void)reader.read(buffer, size);
|
||||
buffer[size] = '\0';
|
||||
reader.close();
|
||||
|
||||
// remove \n character in the buffer
|
||||
std::string so(buffer);
|
||||
std::regex pattern("([\\s\\n]+)");
|
||||
std::string fmt = " ";
|
||||
std::string s = std::regex_replace(so, pattern, fmt);
|
||||
|
||||
// remove the head and tail whiteblanks of the s
|
||||
(void)s.erase(0, s.find_first_not_of(" "));
|
||||
(void)s.erase(s.find_last_not_of(" ") + 1);
|
||||
// append one whiteblanks to the end of s
|
||||
s += " ";
|
||||
*ans = s;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourOp::ReadInfoFile(const std::string &data_dir, const std::string &info_file) {
|
||||
std::vector<uint32_t> tmp;
|
||||
labels_.swap(tmp);
|
||||
std::string info_file_path = (Path(data_dir) / Path(info_file)).ToString();
|
||||
std::string s;
|
||||
RETURN_IF_NOT_OK(GetFileContent(info_file_path, &s));
|
||||
auto get_splited_str = [&s](std::size_t pos) {
|
||||
std::string item = s.substr(0, pos);
|
||||
s = s.substr(pos + 1);
|
||||
return item;
|
||||
};
|
||||
enum ColType { ID_3DPOINT, UNKNOWN };
|
||||
std::size_t pos = 0;
|
||||
ColType col_idx = ID_3DPOINT;
|
||||
while ((pos = s.find(" ")) != std::string::npos) {
|
||||
switch (col_idx) {
|
||||
case ID_3DPOINT: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour info file failed: " + info_file_path);
|
||||
int id_3dpoint = std::atoi(item.c_str());
|
||||
labels_.push_back(id_3dpoint);
|
||||
col_idx = UNKNOWN;
|
||||
break;
|
||||
}
|
||||
case UNKNOWN: {
|
||||
std::string item2 = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item2.empty(), "Reading PhotoTour info file failed: " + info_file_path);
|
||||
col_idx = ID_3DPOINT;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourOp::ReadMatchedFile(const std::string &data_dir, const std::string &matches_file) {
|
||||
std::vector<MatchTuple> tmp;
|
||||
matches_.swap(tmp);
|
||||
std::string info_file_path = (Path(data_dir) / Path(matches_file)).ToString();
|
||||
|
||||
std::string s;
|
||||
RETURN_IF_NOT_OK(GetFileContent(info_file_path, &s));
|
||||
|
||||
auto get_splited_str = [&s](std::size_t pos) {
|
||||
std::string item = s.substr(0, pos);
|
||||
s = s.substr(pos + 1);
|
||||
return item;
|
||||
};
|
||||
enum ColType { PATCH_ID1, LABEL1, UNUSED1, PATCH_ID2, LABEL2, UNUSED2, UNUSED3 };
|
||||
uint32_t patch_id1, label1, patch_id2, label2;
|
||||
std::size_t pos = 0;
|
||||
ColType col_idx = PATCH_ID1;
|
||||
while ((pos = s.find(" ")) != std::string::npos) {
|
||||
switch (col_idx) {
|
||||
case PATCH_ID1: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
patch_id1 = std::atoi(item.c_str());
|
||||
col_idx = LABEL1;
|
||||
break;
|
||||
}
|
||||
case LABEL1: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
label1 = std::atoi(item.c_str());
|
||||
col_idx = UNUSED1;
|
||||
break;
|
||||
}
|
||||
case UNUSED1: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
col_idx = PATCH_ID2;
|
||||
break;
|
||||
}
|
||||
case PATCH_ID2: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
patch_id2 = std::atoi(item.c_str());
|
||||
col_idx = LABEL2;
|
||||
break;
|
||||
}
|
||||
case LABEL2: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
label2 = std::atoi(item.c_str());
|
||||
col_idx = UNUSED2;
|
||||
matches_.push_back(std::make_tuple(patch_id1, patch_id2, uint32_t(label1 == label2)));
|
||||
break;
|
||||
}
|
||||
case UNUSED2: {
|
||||
std::string item = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
col_idx = UNUSED3;
|
||||
break;
|
||||
}
|
||||
case UNUSED3: {
|
||||
std::string item2 = get_splited_str(pos);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!item2.empty(), "Reading PhotoTour matched file failed: " + info_file_path);
|
||||
col_idx = PATCH_ID1;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourOp::GetPhotoTourDataTensor(uint32_t index, std::shared_ptr<Tensor> *image_tensor) {
|
||||
RETURN_UNEXPECTED_IF_NULL(image_tensor);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(index < kLens.at(name_),
|
||||
"Index exceeds the maximum count of image, got: " + std::to_string(index));
|
||||
|
||||
int image_id = index / (kPatchNumPerRow * kPatchNumPerCol);
|
||||
int row_in_image = (index % (kPatchNumPerRow * kPatchNumPerCol)) / kPatchNumPerRow;
|
||||
int col_in_image = (index % (kPatchNumPerRow * kPatchNumPerCol)) % kPatchNumPerRow;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(access_mutex_);
|
||||
if (image_bmps_[image_id].empty()) {
|
||||
image_bmps_[image_id] = cv::imread(image_names_[image_id], 0);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t x = col_in_image * kColPerPatch;
|
||||
uint32_t y = row_in_image * kRowPerPatch;
|
||||
|
||||
cv::Rect myROI(x, y, kColPerPatch, kRowPerPatch);
|
||||
|
||||
// Crop the full image to that image contained by the rectangle myROI
|
||||
// Note that this doesn't copy the data
|
||||
cv::Mat croppedRef(image_bmps_[image_id], myROI);
|
||||
cv::Mat cropped;
|
||||
// Copy the data into new matrix
|
||||
croppedRef.copyTo(cropped);
|
||||
|
||||
uchar *uc_img = cropped.data;
|
||||
TensorShape img_tensor_shape = TensorShape({kRowPerPatch, kColPerPatch, 1});
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(img_tensor_shape, data_schema_->Column(0).Type(), uc_img, image_tensor));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read all files in the directory.
|
||||
// @return Status The status code returned.
|
||||
Status PhotoTourOp::PrepareData() {
|
||||
chosen_dataset_folder_path_ = (Path(dataset_dir_) / Path(name_)).ToString();
|
||||
train_ = kTrain.at(usage_);
|
||||
auto real_folder_path = FileUtils::GetRealPath(chosen_dataset_folder_path_.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + chosen_dataset_folder_path_);
|
||||
|
||||
std::vector<cv::String> file_names;
|
||||
cv::glob(real_folder_path.value(), file_names);
|
||||
image_names_.clear();
|
||||
image_bmps_.clear();
|
||||
for (auto &&file_name : file_names) {
|
||||
if (EndsWith(file_name, kImageExt)) {
|
||||
image_names_.push_back(file_name);
|
||||
}
|
||||
}
|
||||
std::sort(image_names_.begin(), image_names_.end());
|
||||
image_bmps_.resize(image_names_.size());
|
||||
RETURN_IF_NOT_OK(ReadInfoFile(real_folder_path.value(), kInfoFile));
|
||||
RETURN_IF_NOT_OK(ReadMatchedFile(real_folder_path.value(), kMatchesFiles));
|
||||
if (train_) {
|
||||
num_rows_ = labels_.size();
|
||||
} else {
|
||||
num_rows_ = matches_.size();
|
||||
}
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, data file may not suitable to read with PhotoTourDataset API."
|
||||
"Check file in directory: " +
|
||||
chosen_dataset_folder_path_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourOp::CountTotalRows(const std::string &dir, const std::string &name, const std::string &usage,
|
||||
int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
if (usage == "train") {
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image1", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image2", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("matches", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
}
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
auto op = std::make_shared<PhotoTourOp>(dir, name, usage, num_workers, op_connect_size, std::move(schema),
|
||||
std::move(sampler));
|
||||
RETURN_IF_NOT_OK(op->PrepareData());
|
||||
|
||||
if (usage == "train") {
|
||||
*count = op->labels_.size();
|
||||
} else {
|
||||
*count = op->matches_.size();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourOp::ComputeColMap() {
|
||||
// set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PHOTO_TOUR_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PHOTO_TOUR_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Forward declares
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
using MatchTuple = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||
|
||||
class PhotoTourOp : public MappableLeafOp {
|
||||
public:
|
||||
// Constructor
|
||||
// @param const std::string &datasetDir - Path to the root directory that
|
||||
// contains the dataset.
|
||||
// @param const std::string &name - Name of the dataset to load.
|
||||
// @param const std::string &usage - 'train' or 'test', If 'train', the
|
||||
// generated dataset has one column ["image"], else three columns
|
||||
// ["image1", "image2", "matches"].
|
||||
// @param int32_t num_workers - number of workers reading images in parallel.
|
||||
// @param int32_t queue_size - connector queue size.
|
||||
// @param std::unique_ptr<DataSchema> data_schema - the schema of the photo tour dataset.
|
||||
// @param std::unique_ptr<Sampler> sampler - sampler tells PhotoTourOp what to read.
|
||||
PhotoTourOp(const std::string &dataset_dir, const std::string &name, const std::string &usage, int32_t num_workers,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
// Destructor.
|
||||
~PhotoTourOp() = default;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
|
||||
// @param std::map<int32_t, std::vector<int64_t >> *cls_ids - key label, val all ids for this class.
|
||||
// @return Status - The status code returned.
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// A print method typically used for debugging.
|
||||
// @param std::ostream &out - out stream.
|
||||
// @param bool show_all - whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Function to count the number of samples in the PhotoTour dataset.
|
||||
// @param const std::string &dir - path to the PhotoTour directory.
|
||||
// @param const std::string &name - name of the dataset to load.
|
||||
// @param const std::string &usage - 'train' or 'test', If 'train', the
|
||||
// generated dataset has one column ["image"], else three columns
|
||||
// ["image1", "image2", "matches"].
|
||||
// @param int64_t *count - output arg that will hold the minimum of the actual dataset
|
||||
// size and numSamples.
|
||||
// @return Status - The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &name, const std::string &usage,
|
||||
int64_t *count);
|
||||
|
||||
// Op name getter.
|
||||
// @return std::string - Name of the current Op.
|
||||
std::string Name() const override { return "PhotoTourOp"; }
|
||||
|
||||
private:
|
||||
// Load a tensor row according to the row_id.
|
||||
// @param row_id_type row_id - id for this tensor row.
|
||||
// @param TensorRow *row - load one piece of data into this tensor row.
|
||||
// @return Status - The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row);
|
||||
|
||||
// Judge whether string s ends with string sub.
|
||||
// @param const std::string &s - full string.
|
||||
// @param const std::string &sub - suffix.
|
||||
// @return bool The Result of whether string s ends with string sub.
|
||||
bool EndsWith(const std::string &s, const std::string &sub);
|
||||
|
||||
// Read the content in the given file path.
|
||||
// @param const std::string &info_file - info file name.
|
||||
// @param std::string *ans - store the content of the info file.
|
||||
// @return Status - The status code returned.
|
||||
Status GetFileContent(const std::string &info_file, std::string *ans);
|
||||
|
||||
// Read the meta info for each patch.
|
||||
// @param const std::string &data_dir - data_dir stores the info file.
|
||||
// @param const std::string &info_file - info file name.
|
||||
// @return Status - The status code returned.
|
||||
Status ReadInfoFile(const std::string &data_dir, const std::string &info_file);
|
||||
|
||||
// Read the matches meta info.
|
||||
// @param const std::string &data_dir - data_dir stores the info file.
|
||||
// @param const std::string &matches_file - matches info file name.
|
||||
// @return Status - The status code returned.
|
||||
Status ReadMatchedFile(const std::string &data_dir, const std::string &matches_file);
|
||||
|
||||
// Get one piece of PhotoTour data.
|
||||
// @param uint32_t index - index of data to read.
|
||||
// @param std::shared_ptr<Tensor> *image_tensor - store the indexed data.
|
||||
// @return Status - The status code returned.
|
||||
Status GetPhotoTourDataTensor(uint32_t index, std::shared_ptr<Tensor> *image_tensor);
|
||||
|
||||
// Read all files in the directory.
|
||||
// @return Status - The status code returned.
|
||||
Status PrepareData();
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return Status - The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int64_t buf_cnt_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
const std::string dataset_dir_; // directory of image folder
|
||||
const std::string name_; // dataset name
|
||||
const std::string usage_; // 'train' or 'test'
|
||||
std::string chosen_dataset_folder_path_; // dataset_dir + name : folder
|
||||
bool train_; // whether the usage_ is "train" or not
|
||||
|
||||
std::vector<std::string> image_names_;
|
||||
std::vector<cv::Mat> image_bmps_;
|
||||
|
||||
std::vector<MatchTuple> matches_; // train_ = false, stores the triplets (img1, img2, is_match)
|
||||
std::vector<uint32_t> labels_; // label of i_th patch
|
||||
std::mutex access_mutex_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PHOTO_TOUR_OP_H_
|
|
@ -91,6 +91,7 @@ constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
|||
constexpr char kManifestNode[] = "ManifestDataset";
|
||||
constexpr char kMindDataNode[] = "MindDataDataset";
|
||||
constexpr char kMnistNode[] = "MnistDataset";
|
||||
constexpr char kPhotoTourNode[] = "PhotoTourDataset";
|
||||
constexpr char kPlaces365Node[] = "Places365Dataset";
|
||||
constexpr char kQMnistNode[] = "QMnistDataset";
|
||||
constexpr char kRandomNode[] = "RandomDataset";
|
||||
|
|
|
@ -19,6 +19,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
manifest_node.cc
|
||||
minddata_node.cc
|
||||
mnist_node.cc
|
||||
photo_tour_node.cc
|
||||
places365_node.cc
|
||||
qmnist_node.cc
|
||||
random_node.cc
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/photo_tour_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PhotoTourNode::PhotoTourNode(const std::string &dataset_dir, const std::string &name, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), name_(name), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> PhotoTourNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<PhotoTourNode>(dataset_dir_, name_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void PhotoTourNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(name: " + name_ + ", usage: " + usage_);
|
||||
if (sampler_ != nullptr) {
|
||||
out << ", sampler";
|
||||
}
|
||||
if (cache_ != nullptr) {
|
||||
out << ", cache";
|
||||
}
|
||||
out << ")";
|
||||
}
|
||||
|
||||
Status PhotoTourNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("PhotoTourNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("PhotoTourNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("PhotoTourNode", usage_, {"train", "test"}));
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateStringValue("PhotoTourNode", name_,
|
||||
{"notredame", "yosemite", "liberty", "notredame_harris", "yosemite_harris", "liberty_harris"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
if (usage_ == "train") {
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image1", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image2", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("matches", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<PhotoTourOp>(dataset_dir_, name_, usage_, num_workers_, connector_que_size_,
|
||||
std::move(schema), std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node.
|
||||
Status PhotoTourNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size.
|
||||
Status PhotoTourNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(PhotoTourOp::CountTotalRows(dataset_dir_, name_, usage_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PhotoTourNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["name"] = name_;
|
||||
args["usage"] = usage_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PHOTO_TOUR_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PHOTO_TOUR_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class PhotoTourNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
PhotoTourNode(const std::string &dataset_dir, const std::string &name, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~PhotoTourNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kPhotoTourNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[out] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id - The shard ID within num_shards.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed
|
||||
/// up the process of getting dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &GetName() const { return name_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
/// \param[in] sampler - Specify sampler.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string name_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PHOTO_TOUR_NODE_H_
|
|
@ -2843,6 +2843,101 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
|
|||
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class PhotoTourDataset
|
||||
/// \brief A source dataset for reading and parsing PhotoTour dataset.
|
||||
class PhotoTourDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of PhotoTourDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test`.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of PhotoTourDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test`.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of PhotoTourDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test`.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
~PhotoTourDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a PhotoTourDataset.
|
||||
/// \note If usage is 'train', the generated dataset has one column ["image"], else
|
||||
/// three columns ["image1", "image2", "matches"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test` (default="train").
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples
|
||||
/// from the dataset. If sampler is not given, a `RandomSampler` will
|
||||
/// be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current PhotoTourDataset.
|
||||
inline std::shared_ptr<PhotoTourDataset> PhotoTour(
|
||||
const std::string &dataset_dir, const std::string &name, const std::string &usage = "train",
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<PhotoTourDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a PhotoTourDataset.
|
||||
/// \note If usage is 'train', the generated dataset has one column ["image"], else
|
||||
/// three columns ["image1", "image2", "matches"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test`.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current PhotoTourDataset.
|
||||
inline std::shared_ptr<PhotoTourDataset> PhotoTour(const std::string &dataset_dir, const std::string &name,
|
||||
const std::string &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<PhotoTourDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a PhotoTourDataset.
|
||||
/// \note If usage is 'train', the generated dataset has one column ["image"], else
|
||||
/// three columns ["image1", "image2", "matches"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] name Name of the dataset to load, should be one of 'notredame', 'yosemite', 'liberty',
|
||||
/// 'notredame_harris', 'yosemite_harris' or 'liberty_harris'.
|
||||
/// \param[in] usage Part of dataset of PhotoTour, can be `train` or `test`.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current PhotoTourDataset.
|
||||
inline std::shared_ptr<PhotoTourDataset> PhotoTour(const std::string &dataset_dir, const std::string &name,
|
||||
const std::string &usage,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<PhotoTourDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
/// \class Places365Dataset
|
||||
/// \brief A source dataset that reads and parses Places365 dataset.
|
||||
class Places365Dataset : public Dataset {
|
||||
|
|
|
@ -46,6 +46,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class ManifestDataset;
|
||||
friend class MindDataDataset;
|
||||
friend class MnistDataset;
|
||||
friend class PhotoTourDataset;
|
||||
friend class Places365Dataset;
|
||||
friend class QMnistDataset;
|
||||
friend class RandomDataDataset;
|
||||
|
|
|
@ -66,7 +66,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
|
||||
check_photo_tour_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -3464,6 +3465,159 @@ class MnistDataset(MappableDataset):
|
|||
return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class PhotoTourDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the PhotoTour dataset.
|
||||
|
||||
The generated dataset with different usage has different output columns.
|
||||
If train, the generated dataset has one column :py:obj:`[image]`,
|
||||
else three columns :py:obj:`[image1, image2, matches]`.
|
||||
The tensor of column :py:obj:`image`, :py:obj:`image1` and :py:obj:`image2` is of the uint8 type.
|
||||
The tensor of column :py:obj:`matches` is a scalar of the uint32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
name (str): Name of the dataset to load,
|
||||
should be one of 'notredame', 'yosemite', 'liberty', 'notredame_harris',
|
||||
'yosemite_harris' or 'liberty_harris'.
|
||||
usage (str, optional): Usage of the dataset, can be `train` or `test` (Default=None, will be set to 'train').
|
||||
When usage is `train`, number of samples for each `name` is
|
||||
{'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, 'liberty_harris': 379587,
|
||||
'yosemite_harris': 450912, 'notredame_harris': 325295}.
|
||||
When usage is `test`, will read 100,000 samples for testing.
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||
argument can only be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If dataset_dir is not exist.
|
||||
ValueError: If usage is not in ["train", "test"].
|
||||
ValueError: If name is not in ["notredame", "yosemite", "liberty",
|
||||
"notredame_harris", "yosemite_harris", "liberty_harris"].
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive. The table
|
||||
below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
|
||||
:widths: 64 64 1
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> # Read 3 samples from PhotoTour dataset.
|
||||
>>> dataset = ds.PhotoTourDataset(dataset_dir="/path/to/photo_tour_dataset_directory",
|
||||
... name='liberty', usage='train', num_samples=3)
|
||||
>>>
|
||||
>>> # In PhotoTourDataset dataset, if usage is 'train', each dictionary has key "image",
|
||||
>>> # else has keys "image1" "image2" and "matches".
|
||||
|
||||
About PhotoTour dataset:
|
||||
|
||||
The data is taken from Photo Tourism reconstructions from Trevi Fountain (Rome), Notre Dame (Paris) and Half
|
||||
Dome (Yosemite). Each dataset consists of a series of corresponding patches, which are obtained by projecting
|
||||
3D points from Photo Tourism reconstructions back into the original images.
|
||||
|
||||
The dataset consists of 1024 x 1024 bitmap (.bmp) images, each containing a 16 x 16 array of image patches.
|
||||
Each patch is sampled as 64 x 64 grayscale, with a canonical scale and orientation. For details of how the scale
|
||||
and orientation is established, please see the paper. An associated metadata file info.txt contains the match
|
||||
information. Each row of info.txt corresponds to a separate patch, with the patches ordered from left to right and
|
||||
top to bottom in each bitmap image. The first number on each row of info.txt is the 3D point ID from which that
|
||||
patch was sampled -- patches with the same 3D point ID are projected from the same 3D point (into different images).
|
||||
The second number in info.txt corresponds to the image from which the patch was sampled, and is not used at present.
|
||||
|
||||
You can unzip the original PhotoTour dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
.
|
||||
└── photo_tour_dataset_directory
|
||||
├── liberty/
|
||||
│ ├── info.txt // two columns: 3D_point_ID, unused
|
||||
│ ├── m50_100000_100000_0.txt // seven columns: patch_ID1, 3D_point_ID1, unused1,
|
||||
│ │ // patch_ID2, 3D_point_ID2, unused2, unused3
|
||||
│ ├── patches0000.bmp // 1024*1024 pixels, with 16 * 16 patches.
|
||||
│ ├── patches0001.bmp
|
||||
│ ├── ...
|
||||
├── yosemite/
|
||||
│ ├── ...
|
||||
├── notredame/
|
||||
│ ├── ...
|
||||
├── liberty_harris/
|
||||
│ ├── ...
|
||||
├── yosemite_harris/
|
||||
│ ├── ...
|
||||
├── notredame_harris/
|
||||
│ ├── ...
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@INPROCEEDINGS{4269996,
|
||||
author={Winder, Simon A. J. and Brown, Matthew},
|
||||
booktitle={2007 IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
title={Learning Local Image Descriptors},
|
||||
year={2007},
|
||||
volume={},
|
||||
number={},
|
||||
pages={1-8},
|
||||
doi={10.1109/CVPR.2007.382971}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_photo_tour_dataset
|
||||
def __init__(self, dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.name = name
|
||||
self.usage = replace_none(usage, "train")
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.PhotoTourNode(self.dataset_dir, self.name, self.usage, self.sampler)
|
||||
|
||||
|
||||
class Places365Dataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the Places365 dataset.
|
||||
|
|
|
@ -92,6 +92,36 @@ def check_mnist_cifar_dataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_photo_tour_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(PhotoTourDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test"], "usage")
|
||||
name = param_dict.get('name')
|
||||
check_valid_str(name, ["notredame", "yosemite", "liberty", "notredame_harris",
|
||||
"yosemite_harris", "liberty_harris"], "name")
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_places365_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(Places365Dataset)."""
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_minddata_test.cc
|
||||
c_api_dataset_ops_test.cc
|
||||
c_api_dataset_photo_tour_test.cc
|
||||
c_api_dataset_places365_test.cc
|
||||
c_api_dataset_qmnist_test.cc
|
||||
c_api_dataset_randomdata_test.cc
|
||||
|
|
|
@ -0,0 +1,260 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: PhotoTourTrainDataset.
|
||||
/// Description: test basic usage of PhotoTourTrainDataset.
|
||||
/// Expectation: get correct number of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourTrainDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourTrainDataset.";
|
||||
|
||||
// Create a PhotoTour Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "train", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourTestDataset.
|
||||
/// Description: test basic usage of PhotoTourTestDataset.
|
||||
/// Expectation: get correct number of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourTestDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourTestDataset.";
|
||||
|
||||
// Create a PhotoTour Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "test", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image1"), row.end());
|
||||
EXPECT_NE(row.find("image2"), row.end());
|
||||
EXPECT_NE(row.find("matches"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image1"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourTrainDatasetWithPipeline.
|
||||
/// Description: test usage of PhotoTourTrainDataset with pipeline.
|
||||
/// Expectation: get correct number of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourTrainDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourTrainDatasetWithPipeline.";
|
||||
|
||||
// Create two PhotoTour Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds1 = PhotoTour(folder_path, "liberty", "train", std::make_shared<RandomSampler>(false, 10));
|
||||
std::shared_ptr<Dataset> ds2 = PhotoTour(folder_path, "liberty", "train", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 40);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourTrainDatasetSize.
|
||||
/// Description: test usage of get the size of PhotoTourTrainDataset.
|
||||
/// Expectation: get correct number of data.
|
||||
TEST_F(MindDataTestPipeline, TestGetPhotoTourTrainDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetPhotoTourTrainDatasetSize.";
|
||||
|
||||
// Create a PhotoTour Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "train");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 100);
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourTrainDatasetGetters.
|
||||
/// Description: test usage of getters PhotoTourTrainDataset.
|
||||
/// Expectation: get correct number of data and correct tensor shape.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourTrainDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourTrainDatasetGetters.";
|
||||
|
||||
// Create a PhotoTour Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "train");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 100);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 1);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(shapes.size(), 1);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<64,64,1>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 100);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 100);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 100);
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourDatasetFail.
|
||||
/// Description: test failure of PhotoTourDataset.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourDatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourDatasetFail.";
|
||||
|
||||
// Create a PhotoTour Dataset
|
||||
std::shared_ptr<Dataset> ds = PhotoTour("", "", "train", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid PhotoTour input
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourDatasetWithInvalidUsageFail.
|
||||
/// Description: test failure of PhotoTourDataset with invalid usage.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourDatasetWithInvalidUsageFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourDatasetWithInvalidUsageFail.";
|
||||
|
||||
// Create a PhotoTour Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "validation");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid PhotoTour input, validation is not a valid usage
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: PhotoTourDatasetWithNullSamplerFail.
|
||||
/// Description: test failure of PhotoTourDataset with null sampler.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestPhotoTourDatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhotoTourDatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a PhotoTour Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPhotoTourData";
|
||||
std::shared_ptr<Dataset> ds = PhotoTour(folder_path, "liberty", "train", nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid PhotoTour input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 0
|
||||
1 0
|
||||
2 0
|
||||
2 0
|
||||
3 0
|
||||
3 0
|
||||
4 0
|
||||
4 0
|
||||
5 0
|
||||
5 0
|
||||
6 0
|
||||
6 0
|
||||
7 0
|
||||
7 0
|
||||
7 0
|
||||
8 0
|
||||
8 0
|
||||
8 0
|
||||
9 0
|
||||
9 0
|
||||
10 0
|
||||
10 0
|
||||
11 0
|
||||
11 0
|
||||
12 0
|
||||
12 0
|
||||
12 0
|
||||
13 0
|
||||
13 0
|
||||
13 0
|
||||
14 0
|
||||
14 0
|
||||
14 0
|
||||
15 0
|
||||
15 0
|
||||
15 0
|
||||
16 0
|
||||
16 0
|
||||
16 0
|
||||
17 0
|
||||
17 0
|
||||
17 0
|
||||
18 0
|
||||
18 0
|
||||
18 0
|
||||
19 0
|
||||
19 0
|
||||
19 0
|
||||
20 0
|
||||
20 0
|
||||
21 0
|
||||
21 0
|
||||
22 0
|
||||
22 0
|
||||
23 0
|
||||
23 0
|
||||
24 0
|
||||
24 0
|
||||
25 0
|
||||
25 0
|
||||
26 0
|
||||
26 0
|
||||
26 0
|
||||
27 0
|
||||
27 0
|
||||
27 0
|
||||
28 0
|
||||
28 0
|
||||
28 0
|
||||
29 0
|
||||
29 0
|
||||
29 0
|
||||
30 0
|
||||
30 0
|
||||
31 0
|
||||
31 0
|
||||
31 0
|
||||
32 0
|
||||
32 0
|
||||
32 0
|
||||
32 0
|
||||
33 0
|
||||
33 0
|
||||
34 0
|
||||
34 0
|
||||
34 0
|
||||
35 0
|
||||
35 0
|
||||
36 0
|
||||
36 0
|
||||
37 0
|
||||
37 0
|
||||
38 0
|
||||
38 0
|
||||
39 0
|
||||
39 0
|
||||
39 0
|
|
@ -0,0 +1,16 @@
|
|||
0 3960 0 1 3960 0 0
|
||||
2 67354 0 3 159829 0 0
|
||||
4 117877 0 5 117877 0 0
|
||||
6 68371 0 7 68371 0 0
|
||||
8 79408 0 9 79408 0 0
|
||||
10 3705 0 11 3705 0 0
|
||||
2 22229 0 4 102831 0 0
|
||||
4 103585 0 5 103585 0 0
|
||||
12 50522 0 13 50522 0 0
|
||||
14 83120 0 15 110333 0 0
|
||||
14 20882 0 18 20882 0 0
|
||||
21 62246 0 33 62246 0 0
|
||||
22 87077 0 21 87077 0 0
|
||||
23 81798 0 2 143036 0 0
|
||||
23 6295 0 34 55263 0 0
|
||||
12 19226 0 24 19226 0 0
|
Binary file not shown.
After Width: | Height: | Size: 1.0 MiB |
|
@ -0,0 +1,324 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test PhotoTour dataset operator
|
||||
"""
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testPhotoTourData"
|
||||
NAME = 'liberty'
|
||||
LEN = 100
|
||||
|
||||
|
||||
def load_photo_tour_dataset(path, name):
|
||||
"""
|
||||
Feature: load_photo_tour_dataset.
|
||||
Description: load photo tour.
|
||||
Expectation: get data of photo tour dataset.
|
||||
"""
|
||||
def pil2array(img: Image.Image):
|
||||
"""
|
||||
Convert PIL image type to numpy 2D array
|
||||
"""
|
||||
return np.array(img.getdata(), dtype=np.uint8).reshape((64, 64, 1))
|
||||
|
||||
def find_files(data_dir: str, image_ext_: str):
|
||||
"""
|
||||
Return a list with the file names of the images containing the patches
|
||||
"""
|
||||
files = []
|
||||
# find those files with the specified extension
|
||||
for file_dir in os.listdir(data_dir):
|
||||
if file_dir.endswith(image_ext_):
|
||||
files.append(os.path.join(data_dir, file_dir))
|
||||
return sorted(files) # sort files in ascend order to keep relations
|
||||
|
||||
patches = []
|
||||
list_files = find_files(os.path.realpath(os.path.join(path, name)), 'bmp')
|
||||
idx = 0
|
||||
for fpath in list_files:
|
||||
img = Image.open(fpath)
|
||||
for y in range(0, 1024, 64):
|
||||
for x in range(0, 1024, 64):
|
||||
patch = img.crop((x, y, x + 64, y + 64))
|
||||
patches.append(pil2array(patch))
|
||||
idx += 1
|
||||
if idx > LEN:
|
||||
break
|
||||
if idx > LEN:
|
||||
break
|
||||
matches_path = os.path.join(os.path.realpath(os.path.join(path, name)), 'm50_100000_100000_0.txt')
|
||||
matches = []
|
||||
with open(matches_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line_split = line.split()
|
||||
matches.append([int(line_split[0]), int(line_split[3]),
|
||||
int(line_split[1] == line_split[4])])
|
||||
return patches, matches
|
||||
|
||||
|
||||
def visualize_dataset(images1, images2, matches):
|
||||
"""
|
||||
Feature: visualize_dataset.
|
||||
Description: visualize photo tour dataset.
|
||||
Expectation: plot images.
|
||||
"""
|
||||
num_samples = len(images1)
|
||||
for i in range(num_samples):
|
||||
plt.subplot(1, num_samples, i + 1)
|
||||
plt.imshow(images1[i].squeeze(), cmap=plt.cm.gray)
|
||||
plt.title(matches[i])
|
||||
num_samples = len(images2)
|
||||
for i in range(num_samples):
|
||||
plt.subplot(2, num_samples, i + 1)
|
||||
plt.imshow(images2[i].squeeze(), cmap=plt.cm.gray)
|
||||
plt.title(matches[i])
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_photo_tour_content_check():
|
||||
"""
|
||||
Feature: test_photo_tour_content_check.
|
||||
Description: validate PhotoTourDataset image readings.
|
||||
Expectation: get correct number of data and correct content.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset Op with content check")
|
||||
data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10, shuffle=False)
|
||||
images, matches = load_photo_tour_dataset(DATA_DIR, NAME)
|
||||
num_iter = 0
|
||||
# in this example, each dictionary has keys "image1" "image2" and "matches"
|
||||
|
||||
for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
np.testing.assert_array_equal(data["image1"], images[matches[i][0]])
|
||||
np.testing.assert_array_equal(data["image2"], images[matches[i][1]])
|
||||
np.testing.assert_array_equal(data["matches"], matches[i][2])
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_photo_tour_basic():
|
||||
"""
|
||||
Feature: test_photo_tour_basic.
|
||||
Description: test basic usage of PhotoTourDataset.
|
||||
Expectation: get correct number of data.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset Op")
|
||||
|
||||
# case 1: test loading whole dataset
|
||||
data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test')
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 16
|
||||
|
||||
# case 2: test num_samples
|
||||
data2 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 10
|
||||
|
||||
# case 3: test repeat
|
||||
data3 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=5)
|
||||
data3 = data3.repeat(5)
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 25
|
||||
|
||||
# case 4: test batch with drop_remainder=False
|
||||
data4 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
|
||||
assert data4.get_dataset_size() == 10
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.batch(batch_size=7) # drop_remainder is default to be False
|
||||
assert data4.get_dataset_size() == 2
|
||||
assert data4.get_batch_size() == 7
|
||||
num_iter4 = 0
|
||||
for _ in data4.create_dict_iterator(num_epochs=1):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 2
|
||||
|
||||
# case 5: test batch with drop_remainder=True
|
||||
data5 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
|
||||
assert data5.get_dataset_size() == 10
|
||||
assert data5.get_batch_size() == 1
|
||||
data5 = data5.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert data5.get_dataset_size() == 1
|
||||
assert data5.get_batch_size() == 7
|
||||
num_iter5 = 0
|
||||
for _ in data5.create_dict_iterator(num_epochs=1):
|
||||
num_iter5 += 1
|
||||
assert num_iter5 == 1
|
||||
|
||||
# case 6: test get_col_names
|
||||
data6 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
|
||||
assert data6.get_col_names() == ['image1', 'image2', 'matches']
|
||||
|
||||
|
||||
def test_photo_tour_pk_sampler():
|
||||
"""
|
||||
Feature: test_photo_tour_pk_sampler.
|
||||
Description: test usage of PhotoTourDataset with PKSampler.
|
||||
Expectation: get correct number of data.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset Op with PKSampler")
|
||||
golden = [0, 0, 0, 1, 1, 1]
|
||||
sampler = ds.PKSampler(3)
|
||||
data = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=sampler)
|
||||
num_iter = 0
|
||||
matches_list = []
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
matches_list.append(item["matches"])
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(golden, matches_list)
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_photo_tour_sequential_sampler():
|
||||
"""
|
||||
Feature: test_photo_tour_sequential_sampler.
|
||||
Description: test usage of PhotoTourDataset with SequentialSampler.
|
||||
Expectation: get correct number of data.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset Op with SequentialSampler")
|
||||
num_samples = 5
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=sampler)
|
||||
data2 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_samples=num_samples)
|
||||
matches_list1, matches_list2 = [], []
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
|
||||
matches_list1.append(item1["matches"].asnumpy())
|
||||
matches_list2.append(item2["matches"].asnumpy())
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(matches_list1, matches_list2)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_photo_tour_exception():
|
||||
"""
|
||||
Feature: test_photo_tour_exception.
|
||||
Description: test error cases for PhotoTourDataset.
|
||||
Expectation: raise exception.
|
||||
"""
|
||||
logger.info("Test error cases for PhotoTourDataset")
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=2, shard_id="0")
|
||||
|
||||
|
||||
def test_photo_tour_visualize(plot=False):
|
||||
"""
|
||||
Feature: test_photo_tour_visualize.
|
||||
Description: visualize PhotoTourDataset results.
|
||||
Expectation: get correct number of data and plot them.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset visualization")
|
||||
|
||||
data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10, shuffle=False)
|
||||
num_iter = 0
|
||||
image_list1, image_list2, matches_list = [], [], []
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image1 = item["image1"]
|
||||
image2 = item["image2"]
|
||||
matches = item["matches"]
|
||||
image_list1.append(image1)
|
||||
image_list2.append(image2)
|
||||
matches_list.append("matches {}".format(matches))
|
||||
assert isinstance(image1, np.ndarray)
|
||||
assert isinstance(image2, np.ndarray)
|
||||
assert image1.shape == (64, 64, 1)
|
||||
assert image1.dtype == np.uint8
|
||||
assert image2.shape == (64, 64, 1)
|
||||
assert image2.dtype == np.uint8
|
||||
assert matches.dtype == np.uint32
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
if plot:
|
||||
visualize_dataset(image_list1, image_list2, matches_list)
|
||||
|
||||
|
||||
def test_photo_tour_usage():
|
||||
"""
|
||||
Feature: test_photo_tour_usage.
|
||||
Description: validate PhotoTourDataset image readings.
|
||||
Expectation: get correct number of data.
|
||||
"""
|
||||
logger.info("Test PhotoTourDataset usage flag")
|
||||
|
||||
def test_config(photo_tour_path, name, usage):
|
||||
try:
|
||||
data = ds.PhotoTourDataset(photo_tour_path, name, usage, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config(DATA_DIR, NAME, "test") == 16
|
||||
assert test_config(DATA_DIR, NAME, "train") == LEN
|
||||
assert "usage is not within the valid set of ['train', 'test']" in test_config(DATA_DIR, NAME, "invalid")
|
||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(DATA_DIR, NAME, ["list"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_photo_tour_content_check()
|
||||
test_photo_tour_basic()
|
||||
test_photo_tour_pk_sampler()
|
||||
test_photo_tour_sequential_sampler()
|
||||
test_photo_tour_exception()
|
||||
test_photo_tour_visualize(plot=True)
|
||||
test_photo_tour_usage()
|
Loading…
Reference in New Issue