forked from mindspore-Ecosystem/mindspore
!17486 [assistant][ops] Add new dataset operator Flickr.
Merge pull request !17486 from Rainfor/wangkc
This commit is contained in:
commit
f4e07c0dce
|
@ -94,6 +94,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
@ -927,6 +928,32 @@ CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds =
|
||||
std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds =
|
||||
std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds =
|
||||
std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler,
|
||||
const std::set<std::vector<char>> &extensions,
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
@ -122,6 +123,17 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FlickrNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, py::handle sampler) {
|
||||
auto flickr =
|
||||
std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(flickr->ValidateParams());
|
||||
return flickr;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||
|
|
|
@ -16,6 +16,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
album_op.cc
|
||||
mappable_leaf_op.cc
|
||||
nonmappable_leaf_op.cc
|
||||
flickr_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/flickr_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
FlickrOp::FlickrOp(int32_t num_workers, const std::string &dataset_dir, const std::string &file_path, bool decode,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
dataset_dir_(dataset_dir),
|
||||
file_path_(file_path),
|
||||
decode_(decode),
|
||||
data_schema_(std::move(data_schema)) {
|
||||
io_block_queues_.Init(num_workers_, queue_size);
|
||||
}
|
||||
|
||||
Status FlickrOp::LaunchThreadsAndInitOp() {
|
||||
if (tree_ == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&FlickrOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
TaskManager::FindMe()->Post();
|
||||
// The order of the following 2 functions must not be changed!
|
||||
RETURN_IF_NOT_OK(ParseFlickrData()); // Parse Flickr data and get num rows, blocking
|
||||
RETURN_IF_NOT_OK(CountDatasetInfo()); // Count the total rows
|
||||
RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Load 1 TensorRow (image, annotations) using 1 ImageLabelPair. 1 function call produces 1 TensorTow
|
||||
Status FlickrOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
std::pair<std::string, std::vector<std::string>> data = image_annotation_pairs_[static_cast<size_t>(row_id)];
|
||||
std::shared_ptr<Tensor> image;
|
||||
std::shared_ptr<Tensor> annotations;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.first, &image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(data.second, &annotations));
|
||||
|
||||
if (decode_ == true) {
|
||||
Status rc = Decode(image, &image);
|
||||
if (rc.IsError()) {
|
||||
std::string err = "Invalid data, failed to decode image: " + data.first;
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(annotations)});
|
||||
trow->setPath({data.first, file_path_});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void FlickrOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows:" << num_rows_ << "\nFlickr DatasetDir: " << dataset_dir_
|
||||
<< "\nAnnotationFile: " << file_path_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status FlickrOp::ParseFlickrData() {
|
||||
std::ifstream file_handle(file_path_);
|
||||
if (!file_handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Flickr annotation file: " + file_path_);
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int32_t flag_idx;
|
||||
std::string sub_str_flag = "\t";
|
||||
std::string image_file_path;
|
||||
std::string image_name;
|
||||
std::map<std::string, std::vector<std::string>> image_annotation_map_;
|
||||
Path dataset_dir(dataset_dir_);
|
||||
while (getline(file_handle, line)) {
|
||||
try {
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
flag_idx = line.find_first_of(sub_str_flag);
|
||||
image_name = line.substr(0, flag_idx - 2); // -2 because "#[0-4]\t"
|
||||
if (image_name.empty()) {
|
||||
file_handle.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, image_name is not found in Flickr annotation file: " + file_path_ +
|
||||
"; line: " + line);
|
||||
}
|
||||
|
||||
image_file_path = (dataset_dir / image_name).toString();
|
||||
std::string annotation = line.substr(flag_idx + 1);
|
||||
if (annotation.empty()) {
|
||||
file_handle.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, annotation is not found in Flickr annotation file: " + file_path_ +
|
||||
"; line: " + line);
|
||||
}
|
||||
|
||||
bool valid = false;
|
||||
RETURN_IF_NOT_OK(CheckImageType(image_file_path, &valid));
|
||||
if (!valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
image_annotation_map_[image_file_path].emplace_back(annotation);
|
||||
} catch (const std::exception &err) {
|
||||
file_handle.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Flickr annotation file: " + file_path_);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto item : image_annotation_map_) {
|
||||
image_annotation_pairs_.emplace_back(std::make_pair(item.first, item.second));
|
||||
}
|
||||
|
||||
file_handle.close();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Only support JPEG/PNG/GIF/BMP
|
||||
// Optimization: Could take in a tensor
|
||||
// This function does not return status because we want to just skip bad input, not crash
|
||||
Status FlickrOp::CheckImageType(const std::string &file_name, bool *valid) {
|
||||
std::ifstream file_handle;
|
||||
constexpr int read_num = 3;
|
||||
*valid = false;
|
||||
file_handle.open(file_name, std::ios::binary | std::ios::in);
|
||||
if (!file_handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open image file: " + file_name);
|
||||
}
|
||||
unsigned char file_type[read_num];
|
||||
(void)file_handle.read(reinterpret_cast<char *>(file_type), read_num);
|
||||
|
||||
if (file_handle.fail()) {
|
||||
file_handle.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to read image file: " + file_name);
|
||||
}
|
||||
file_handle.close();
|
||||
if (file_type[0] == 0xff && file_type[1] == 0xd8 && file_type[2] == 0xff) {
|
||||
// Normal JPEGs start with \xff\xd8\xff\xe0
|
||||
// JPEG with EXIF stats with \xff\xd8\xff\xe1
|
||||
// Use \xff\xd8\xff to cover both.
|
||||
*valid = true;
|
||||
} else if (file_type[0] == 0x89 && file_type[1] == 0x50 && file_type[2] == 0x4e) {
|
||||
// It's a PNG
|
||||
*valid = true;
|
||||
} else if (file_type[0] == 0x47 && file_type[1] == 0x49 && file_type[2] == 0x46) {
|
||||
// It's a GIF
|
||||
*valid = true;
|
||||
} else if (file_type[0] == 0x42 && file_type[1] == 0x4d) {
|
||||
// It's a BMP
|
||||
*valid = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FlickrOp::CountDatasetInfo() {
|
||||
num_rows_ = static_cast<int64_t>(image_annotation_pairs_.size());
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, no valid data matching the dataset API FlickrDataset. Please check file path or dataset API.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FlickrOp::CountTotalRows(const std::string &dir, const std::string &file, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseFlickrData()
|
||||
*count = 0;
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto new_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
|
||||
// build a new unique schema object
|
||||
auto new_schema = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
new_schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(new_schema->AddColumn(
|
||||
ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
std::shared_ptr<FlickrOp> op = std::make_shared<FlickrOp>(num_workers, dir, file, false, op_connect_size,
|
||||
std::move(new_schema), std::move(new_sampler));
|
||||
|
||||
RETURN_IF_NOT_OK(op->ParseFlickrData());
|
||||
*count = static_cast<int64_t>(op->image_annotation_pairs_.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FlickrOp::ComputeColMap() {
|
||||
// Set the column name map (base class field)
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->column(i).name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FLICKR_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FLICKR_OP_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class FlickrOp : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] int32_t num_workers - Num of workers reading images in parallel
|
||||
/// \param[in] std::string dataset_dir - dir directory of Flickr dataset
|
||||
/// \param[in] std::string annotation_file - dir directory of annotation file
|
||||
/// \param[in] int32_t queue_size - connector queue size
|
||||
/// \param[in] std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
FlickrOp(int32_t num_workers, const std::string &dataset_dir, const std::string &annotation_file, bool decode,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~FlickrOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param[out] out
|
||||
/// \param[in] show_all
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Function to count the number of samples in the Flickr dataset
|
||||
/// \param[in] dir - path to the Flickr directory
|
||||
/// \param[in] file - path to the annotation file
|
||||
/// \param[out] count - output arg that will hold the actual dataset size
|
||||
/// \return Status - The status code returned
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &file, int64_t *count);
|
||||
|
||||
/// \brief Op name getter
|
||||
/// \return Name of the current Op
|
||||
std::string Name() const override { return "FlickrOp"; }
|
||||
|
||||
private:
|
||||
/// \brief Load a tensor row according to a pair
|
||||
/// \param[in] uint64_t index - index need to load
|
||||
/// \param[out] TensorRow row - image & annotation read into this tensor row
|
||||
/// \return Status - The status code returned
|
||||
Status LoadTensorRow(row_id_type index, TensorRow *trow) override;
|
||||
|
||||
/// \brief Called first when function is called
|
||||
/// \return Status - The status code returned
|
||||
Status LaunchThreadsAndInitOp() override;
|
||||
|
||||
/// \brief Parse Flickr data
|
||||
/// \return Status - The status code returned
|
||||
Status ParseFlickrData();
|
||||
|
||||
/// \brief Check if image ia valid.Only support JPEG/PNG/GIF/BMP
|
||||
/// \param[in] std::string file_name - image file name need to be checked
|
||||
/// \param[out] bool valid - whether the image type is valid
|
||||
/// \return Status - The status code returned
|
||||
Status CheckImageType(const std::string &file_name, bool *valid);
|
||||
|
||||
/// \brief Count annotation index,num rows and num samples
|
||||
/// \return Status - The status code returned
|
||||
Status CountDatasetInfo();
|
||||
|
||||
/// \brief Private function for computing the assignment of the column name map.
|
||||
/// \return Status - The status code returned
|
||||
Status ComputeColMap() override;
|
||||
|
||||
std::string dataset_dir_;
|
||||
std::string file_path_;
|
||||
bool decode_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<std::string>>> image_annotation_pairs_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FLICKR_OP_H_
|
|
@ -81,6 +81,7 @@ constexpr char kCifar10Node[] = "Cifar10Dataset";
|
|||
constexpr char kCLUENode[] = "CLUEDataset";
|
||||
constexpr char kCocoNode[] = "CocoDataset";
|
||||
constexpr char kCSVNode[] = "CSVDataset";
|
||||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
constexpr char kManifestNode[] = "ManifestDataset";
|
||||
|
|
|
@ -10,6 +10,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
clue_node.cc
|
||||
coco_node.cc
|
||||
csv_node.cc
|
||||
flickr_node.cc
|
||||
image_folder_node.cc
|
||||
manifest_node.cc
|
||||
minddata_node.cc
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/flickr_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for FlickrNode
|
||||
FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
annotation_file_(annotation_file),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> FlickrNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void FlickrNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(dataset dir:" + dataset_dir_;
|
||||
out << ", annotation file:" + annotation_file_;
|
||||
if (sampler_ != nullptr) {
|
||||
out << ", sampler";
|
||||
}
|
||||
if (cache_ != nullptr) {
|
||||
out << ", cache";
|
||||
}
|
||||
out << ")";
|
||||
}
|
||||
|
||||
Status FlickrNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrNode", dataset_dir_));
|
||||
|
||||
if (annotation_file_.empty()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file is not specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
|
||||
for (char c : annotation_file_) {
|
||||
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
|
||||
if (p != forbidden_symbols.end()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
Path annotation_file(annotation_file_);
|
||||
if (!annotation_file.Exists()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] is invalid or not exist.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrNode", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build FlickrOp for Flickr
|
||||
Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_rt));
|
||||
flickr_op->set_total_repeats(GetTotalRepeats());
|
||||
flickr_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(flickr_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status FlickrNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status FlickrNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(FlickrOp::CountTotalRows(dataset_dir_, annotation_file_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FlickrNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["annotation_file"] = annotation_file_;
|
||||
args["decode"] = decode_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FLICKR_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FLICKR_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FlickrNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~FlickrNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kFlickrNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param[out] out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param[out] node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Getter functions
|
||||
const std::string &AnnotationFile() const { return annotation_file_; }
|
||||
|
||||
/// \brief Getter functions
|
||||
bool Decode() const { return decode_; }
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string annotation_file_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FLICKR_NODE_H_
|
|
@ -1091,6 +1091,64 @@ inline std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_f
|
|||
cache);
|
||||
}
|
||||
|
||||
class FlickrDataset : public Dataset {
|
||||
public:
|
||||
explicit FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
explicit FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
explicit FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
~FlickrDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a FlickrDataset
|
||||
/// \notes The generated dataset has two columns ["image", "annotation"]
|
||||
/// \param[in] dataset_dir The dataset dir to be read
|
||||
/// \param[in] annotation_file The annotation file to be read
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FlickrDataset
|
||||
inline std::shared_ptr<FlickrDataset> Flickr(
|
||||
const std::string &dataset_dir, const std::string &annotation_file, bool decode = false,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FlickrDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), decode, sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FlickrDataset
|
||||
/// \notes The generated dataset has two columns ["image", "annotation"]
|
||||
/// \param[in] dataset_dir The dataset dir to be read
|
||||
/// \param[in] annotation_file The annotation file to be read
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FlickrDataset
|
||||
inline std::shared_ptr<FlickrDataset> Flickr(const std::string &dataset_dir, const std::string &annotation_file,
|
||||
bool decode, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FlickrDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), decode, sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FlickrDataset
|
||||
/// \notes The generated dataset has two columns ["image", "annotation"]
|
||||
/// \param[in] dataset_dir The dataset dir to be read
|
||||
/// \param[in] annotation_file The annotation file to be read
|
||||
/// \param[in] decode Decode the images after reading.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FlickrDataset
|
||||
inline std::shared_ptr<FlickrDataset> Flickr(const std::string &dataset_dir, const std::string &annotation_file,
|
||||
bool decode, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FlickrDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), decode, sampler,
|
||||
cache);
|
||||
}
|
||||
|
||||
class ImageFolderDataset : public Dataset {
|
||||
public:
|
||||
explicit ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
|
||||
|
|
|
@ -37,6 +37,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class CLUEDataset;
|
||||
friend class CocoDataset;
|
||||
friend class CSVDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class ManifestDataset;
|
||||
friend class MindDataDataset;
|
||||
|
|
|
@ -63,7 +63,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_celebadataset, check_minddataset, check_generatordataset, check_sync_wait, check_zip_dataset, \
|
||||
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
|
||||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_sb_dataset
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -5675,6 +5676,154 @@ class PaddedDataset(GeneratorDataset):
|
|||
self.padded_samples = padded_samples
|
||||
|
||||
|
||||
class FlickrDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Flickr8k and Flickr30k dataset.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[image, annotation]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`annotation` is a tensor which contains 5 annotations string,
|
||||
such as ["a", "b", "c", "d", "e"].
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
annotation_file (str): Path to the root directory that contains the annotation.
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
|
||||
order behavior shown in the table).
|
||||
decode (bool, optional): Decode the images after reading (default=False).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into (default=None). When this argument is specified, `num_samples` reflects
|
||||
the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir is not valid or does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If dataset_dir is not exist.
|
||||
ValueError: If annotation_file is not exist.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> flickr_dataset_dir = "/path/to/flickr_dataset_directory"
|
||||
>>> annotation_file = "/path/to/flickr_annotation_file"
|
||||
>>>
|
||||
>>> # 1) Get all samples from FLICKR dataset in sequence
|
||||
>>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
|
||||
... annotation_file=annotation_file,
|
||||
... shuffle=False)
|
||||
>>>
|
||||
>>> # 2) Randomly select 350 samples from FLICKR dataset
|
||||
>>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
|
||||
... annotation_file=annotation_file,
|
||||
... num_samples=350,
|
||||
... shuffle=True)
|
||||
>>>
|
||||
>>> # 3) Get samples from FLICKR dataset for shard 0 in a 2-way distributed training
|
||||
>>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
|
||||
... annotation_file=annotation_file,
|
||||
... num_shards=2,
|
||||
... shard_id=0)
|
||||
>>>
|
||||
>>> # In FLICKR dataset, each dictionary has keys "image" and "annotation"
|
||||
|
||||
About Flickr8k dataset:
|
||||
| The Flickr8k dataset consists of 8092 colour images. There are 40460 annotations in the Flickr8k.token.txt,
|
||||
each image has 5 annotations.
|
||||
|
||||
| You can unzip the dataset files into the following directory structure and read by MindSpore's API.
|
||||
| .
|
||||
| └── Flickr8k
|
||||
| ├── Flickr8k_Dataset
|
||||
| | ├── 1000268201_693b08cb0e.jpg
|
||||
| | ├── 1001773457_577c3a7d70.jpg
|
||||
| | ├── ...
|
||||
| └── Flickr8k.token.txt
|
||||
|
||||
.. code-block::
|
||||
|
||||
M. Hodosh, P. Young and J. Hockenmaier (2013)
|
||||
"Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics"
|
||||
Journal of Artificial Intellegence Research, Volume 47, pages 853-899
|
||||
http://www.jair.org/papers/paper3994.html
|
||||
|
||||
About Flickr30k dataset:
|
||||
| The Flickr30k dataset consists of 31783 colour images. There are 158915 annotations in
|
||||
the results_20130124.token, each image has 5 annotations.
|
||||
|
||||
| You can unzip the dataset files into the following directory structure and read by MindSpore's API.
|
||||
| .
|
||||
| └── Flickr30k
|
||||
| ├── flickr30k-images
|
||||
| | ├── 1000092795.jpg
|
||||
| | ├── 10002456.jpg
|
||||
| | ├── ...
|
||||
| └── results_20130124.token
|
||||
|
||||
.. code-block::
|
||||
|
||||
P. Young, A. Lai, M. Hodosh, and J. Hockenmaier.
|
||||
From image description to visual denotations:
|
||||
New similarity metrics for semantic inference over event descriptions.
|
||||
Transactions of the Association for Computational Linguistics (to appear).
|
||||
"""
|
||||
|
||||
@check_flickr_dataset
|
||||
def __init__(self, dataset_dir, annotation_file, num_samples=None, num_parallel_workers=None, shuffle=None,
|
||||
decode=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.annotation_file = annotation_file
|
||||
self.decode = replace_none(decode, False)
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.FlickrNode(self.dataset_dir, self.annotation_file, self.decode, self.sampler)
|
||||
|
||||
|
||||
class SBDataset(GeneratorDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Semantic Boundaries Dataset.
|
||||
|
|
|
@ -1332,6 +1332,34 @@ def check_to_device_send(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_flickr_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle', 'decode']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
annotation_file = param_dict.get('annotation_file')
|
||||
check_dir(dataset_dir)
|
||||
check_file(annotation_file)
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_sb_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_coco_test.cc
|
||||
c_api_dataset_config_test.cc
|
||||
c_api_dataset_csv_test.cc
|
||||
c_api_dataset_flickr_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_minddata_test.cc
|
||||
|
|
|
@ -0,0 +1,308 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrBasic.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
|
||||
// Create a Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds = Flickr(dataset_path, file_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrBasicWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrBasicWithPipeline.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
|
||||
// Create two Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds1 = Flickr(dataset_path, file_path);
|
||||
std::shared_ptr<Dataset> ds2 = Flickr(dataset_path, file_path);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 3;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrGetters.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path1 = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
std::string file_path2 = datasets_root_path_ + "/testFlickrData/flickr30k/test2.token";
|
||||
|
||||
// Create a Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds1 = Flickr(dataset_path, file_path1);
|
||||
std::shared_ptr<Dataset> ds2 = Flickr(dataset_path, file_path2);
|
||||
std::vector<std::string> column_names = {"image", "annotation"};
|
||||
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_EQ(ds1->GetDatasetSize(), 2);
|
||||
EXPECT_EQ(ds1->GetColumnNames(), column_names);
|
||||
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_EQ(ds2->GetDatasetSize(), 3);
|
||||
EXPECT_EQ(ds2->GetColumnNames(), column_names);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrAnnotations) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrGetters.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test3.token";
|
||||
std::shared_ptr<Dataset> ds = Flickr(dataset_path, file_path);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::shared_ptr<Tensor> a_expect_item;
|
||||
std::vector<std::string> annotation_arr;
|
||||
annotation_arr.emplace_back("This is a banana.");
|
||||
annotation_arr.emplace_back("This is a yellow banana.");
|
||||
annotation_arr.emplace_back("This is a banana on the table.");
|
||||
annotation_arr.emplace_back("The banana is yellow.");
|
||||
annotation_arr.emplace_back("The banana is very big.");
|
||||
|
||||
ASSERT_OK(Tensor::CreateFromVector(annotation_arr, &a_expect_item));
|
||||
mindspore::MSTensor expect_item = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(a_expect_item));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto annotation = row["annotation"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
MS_LOG(INFO) << "Tensor annotation shape: " << annotation.Shape();
|
||||
|
||||
EXPECT_MSTENSOR_EQ(annotation, expect_item);
|
||||
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 1);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrDecode) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrDecode.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
// Create a Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds = Flickr(dataset_path, file_path, true, std::make_shared<RandomSampler>());
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto shape = image.Shape();
|
||||
MS_LOG(INFO) << "Tensor image shape size: " << shape.size();
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
EXPECT_GT(shape.size(), 1); // Verify decode=true took effect
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrNumSamplers) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrNumSamplers.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
// Create a Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds = Flickr(dataset_path, file_path, true, std::make_shared<SequentialSampler>(0, 1));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto annotation = row["annotation"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
|
||||
auto a_it = annotation.Shape().begin();
|
||||
for (; a_it != annotation.Shape().end(); ++a_it) {
|
||||
std::cout << "annotation shape " << *a_it << std::endl;
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 1);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
// Create a Flickr30k Dataset with non-existing dataset dir
|
||||
std::shared_ptr<Dataset> ds0 = Flickr("NotExistFile", file_path);
|
||||
EXPECT_NE(ds0, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
|
||||
// Expect failure: invalid Flickr30k input
|
||||
EXPECT_EQ(iter0, nullptr);
|
||||
|
||||
// Create a Flickr30k Dataset with non-existing annotation file
|
||||
std::shared_ptr<Dataset> ds1 = Flickr(dataset_path, "NotExistFile");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid Flickr30k input
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
|
||||
// Create a Flickr30k Dataset with invalid string of dataset dir
|
||||
std::shared_ptr<Dataset> ds2 = Flickr(":*?\"<>|`&;'", file_path);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
// Expect failure: invalid Flickr30k input
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
|
||||
// Create a Flickr30k Dataset with invalid string of annotation file
|
||||
std::shared_ptr<Dataset> ds3 = Flickr(dataset_path, ":*?\"<>|`&;'");
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid Flickr30k input
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestFlickrWithNullSamplerError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFlickrWithNullSamplerError.";
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testFlickrData/flickr30k/flickr30k-images";
|
||||
std::string file_path = datasets_root_path_ + "/testFlickrData/flickr30k/test1.token";
|
||||
// Create a Flickr30k Dataset
|
||||
std::shared_ptr<Dataset> ds = Flickr(dataset_path, file_path, false, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Flickr30k input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
Binary file not shown.
After Width: | Height: | Size: 172 KiB |
Binary file not shown.
After Width: | Height: | Size: 72 KiB |
Binary file not shown.
After Width: | Height: | Size: 51 KiB |
Binary file not shown.
After Width: | Height: | Size: 79 KiB |
Binary file not shown.
After Width: | Height: | Size: 82 KiB |
|
@ -0,0 +1,10 @@
|
|||
000001.jpg#0 This is \*a banana.
|
||||
000001.jpg#1 This is a yellow banana.
|
||||
000001.jpg#2 This is a banana on the table.
|
||||
000001.jpg#3 The banana is yellow.
|
||||
000001.jpg#4 The banana is very big.
|
||||
000002.jpg#0 This is a pen.
|
||||
000002.jpg#1 This is a red and black pen.
|
||||
000002.jpg#2 This is a pen on the table.
|
||||
000002.jpg#3 The color of the pen is red and black.
|
||||
000002.jpg#4 The pen has two colors.
|
|
@ -0,0 +1,15 @@
|
|||
000003.jpg#0 This is an orange.
|
||||
000003.jpg#1 This is an big orange.
|
||||
000003.jpg#2 This is an orange on the pad.
|
||||
000003.jpg#3 The orange looks like very delicious.
|
||||
000003.jpg#4 The orange is on the black pad.
|
||||
000004.jpg#0 This is a key.
|
||||
000004.jpg#1 This is a delicate key.
|
||||
000004.jpg#2 This is a key on the table.
|
||||
000004.jpg#3 The key looks like very old.
|
||||
000004.jpg#4 The key can be used to open door.
|
||||
000005.jpg#0 This is a cup.
|
||||
000005.jpg#1 This is a beautiful and delicate cup.
|
||||
000005.jpg#2 This is a cup on the table.
|
||||
000005.jpg#3 The cup is white and the pattern is blue.
|
||||
000005.jpg#4 The beautiful cup is on the black table.
|
|
@ -0,0 +1,5 @@
|
|||
000001.jpg#0 This is a banana.
|
||||
000001.jpg#1 This is a yellow banana.
|
||||
000001.jpg#2 This is a banana on the table.
|
||||
000001.jpg#3 The banana is yellow.
|
||||
000001.jpg#4 The banana is very big.
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
|
||||
FLICKR30K_DATASET_DIR = "../data/dataset/testFlickrData/flickr30k/flickr30k-images"
|
||||
FLICKR30K_ANNOTATION_FILE_1 = "../data/dataset/testFlickrData/flickr30k/test1.token"
|
||||
FLICKR30K_ANNOTATION_FILE_2 = "../data/dataset/testFlickrData/flickr30k/test2.token"
|
||||
|
||||
|
||||
def visualize_dataset(images, labels):
|
||||
"""
|
||||
Helper function to visualize the dataset samples
|
||||
"""
|
||||
plt.figure(figsize=(10, 10))
|
||||
for i, item in enumerate(zip(images, labels), start=1):
|
||||
plt.imshow(item[0])
|
||||
plt.title('\n'.join([s.decode('utf-8') for s in item[1]]))
|
||||
plt.savefig('./flickr_' + str(i) + '.jpg')
|
||||
|
||||
|
||||
def test_flickr30k_dataset_train(plot=False):
|
||||
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
|
||||
count = 0
|
||||
images_list = []
|
||||
annotation_list = []
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
images_list.append(item['image'])
|
||||
annotation_list.append(item['annotation'])
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
if plot:
|
||||
visualize_dataset(images_list, annotation_list)
|
||||
|
||||
|
||||
def test_flickr30k_dataset_annotation_check():
|
||||
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True, shuffle=False)
|
||||
count = 0
|
||||
expect_annotation_arr = [
|
||||
np.array([
|
||||
r'This is \*a banana.',
|
||||
'This is a yellow banana.',
|
||||
'This is a banana on the table.',
|
||||
'The banana is yellow.',
|
||||
'The banana is very big.',
|
||||
]),
|
||||
np.array([
|
||||
'This is a pen.',
|
||||
'This is a red and black pen.',
|
||||
'This is a pen on the table.',
|
||||
'The color of the pen is red and black.',
|
||||
'The pen has two colors.',
|
||||
])
|
||||
]
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
annotation = [s.decode("utf8") for s in item["annotation"]]
|
||||
np.testing.assert_array_equal(annotation, expect_annotation_arr[count])
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_flickr30k_dataset_basic():
|
||||
# case 1: test num_samples
|
||||
data1 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, num_samples=2, decode=True)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 2
|
||||
|
||||
# case 2: test repeat
|
||||
data2 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
|
||||
data2 = data2.repeat(5)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 10
|
||||
|
||||
# case 3: test batch with drop_remainder=False
|
||||
data3 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
|
||||
resize_op = c_vision.Resize((100, 100))
|
||||
data3 = data3.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
|
||||
assert data3.get_dataset_size() == 3
|
||||
assert data3.get_batch_size() == 1
|
||||
data3 = data3.batch(batch_size=2) # drop_remainder is default to be False
|
||||
assert data3.get_dataset_size() == 2
|
||||
assert data3.get_batch_size() == 2
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 2
|
||||
|
||||
# case 4: test batch with drop_remainder=True
|
||||
data4 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
|
||||
resize_op = c_vision.Resize((100, 100))
|
||||
data4 = data4.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
|
||||
assert data4.get_dataset_size() == 3
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert data4.get_dataset_size() == 1
|
||||
assert data4.get_batch_size() == 2
|
||||
num_iter4 = 0
|
||||
for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 1
|
||||
|
||||
|
||||
def test_flickr30k_dataset_exception():
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
try:
|
||||
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
|
||||
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
try:
|
||||
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
|
||||
data = data.map(operations=exception_func, input_columns=["annotation"], num_parallel_workers=1)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flickr30k_dataset_train(False)
|
||||
test_flickr30k_dataset_annotation_check()
|
||||
test_flickr30k_dataset_basic()
|
||||
test_flickr30k_dataset_exception()
|
Loading…
Reference in New Issue