[feat][assistant][I3J6VH] add new data operator SBU

ckczzj 2021-06-02 16:52:40 +08:00
parent cc7a2b74ac
commit af97ba4a77
23 changed files with 1346 additions and 1 deletion

View File

@@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
@@ -1277,6 +1278,27 @@ RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vecto
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
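// The three SBUDataset constructors below differ only in how the caller hands over the
// sampler (shared_ptr, raw pointer, or reference_wrapper); each is normalized to a
// SamplerObj via Parse() before the SBUNode IR node is built.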
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,

View File

@@ -44,6 +44,7 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
@@ -264,6 +265,16 @@ PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
"to create an SBUNode")
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler) {
auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(sbu->ValidateParams());
return sbu;
}));
}));
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode")

View File

@@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
cifar_op.cc
random_data_op.cc
celeba_op.cc
sbu_op.cc
text_file_op.cc
clue_op.cc
csv_op.cc

View File

@@ -0,0 +1,234 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/sbu_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
SBUOp::SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
folder_path_(folder_path),
decode_(decode),
url_path_(""),
caption_path_(""),
image_folder_(""),
data_schema_(std::move(data_schema)) {
io_block_queues_.Init(num_workers, queue_size);
}
void SBUOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows: " << num_rows_ << "\nSBU directory: " << folder_path_
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
}
}
// Load 1 TensorRow (image, caption) using 1 SBUImageCaptionPair.
Status SBUOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
RETURN_UNEXPECTED_IF_NULL(trow);
SBUImageCaptionPair image_caption_pair = image_caption_pairs_[row_id];
Path path = image_caption_pair.first;
std::shared_ptr<Tensor> image, caption;
RETURN_IF_NOT_OK(ReadImageToTensor(path.ToString(), &image));
RETURN_IF_NOT_OK(Tensor::CreateScalar(image_caption_pair.second, &caption));
(*trow) = TensorRow(row_id, {std::move(image), std::move(caption)});
trow->setPath({path.ToString()});
return Status::OK();
}
Status SBUOp::ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) {
RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
if (decode_) {
Status rc = Decode(*tensor, tensor);
if (rc.IsError()) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to decode image: " + path);
}
}
return Status::OK();
}
Status SBUOp::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
Status SBUOp::CountTotalRows(const std::string &dir, int64_t *count) {
RETURN_UNEXPECTED_IF_NULL(count);
*count = 0;
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connector_size = cfg->op_connector_size();
// decode does not affect the count result, so default it to true.
auto op = std::make_shared<SBUOp>(dir, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);
// the logic of counting the number of samples
RETURN_IF_NOT_OK(op->ParseSBUData());
*count = op->image_caption_pairs_.size();
return Status::OK();
}
Status SBUOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&SBUOp::WorkerEntry, this, std::placeholders::_1), "", id()));
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(this->ParseSBUData());
RETURN_IF_NOT_OK(this->InitSampler());  // handshake with the sampler
return Status::OK();
}
Status SBUOp::ParseSBUData() {
const Path url_file_name("SBU_captioned_photo_dataset_urls.txt");
const Path caption_file_name("SBU_captioned_photo_dataset_captions.txt");
const Path image_folder_name("sbu_images");
auto real_folder_path = Common::GetRealPath(folder_path_);
CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
Path root_dir(real_folder_path.value());
url_path_ = root_dir / url_file_name;
CHECK_FAIL_RETURN_UNEXPECTED(url_path_.Exists() && !url_path_.IsDirectory(),
"Invalid file, failed to find SBU url file: " + url_path_.ToString());
MS_LOG(INFO) << "SBU operator found url file " << url_path_.ToString() << ".";
caption_path_ = root_dir / caption_file_name;
CHECK_FAIL_RETURN_UNEXPECTED(caption_path_.Exists() && !caption_path_.IsDirectory(),
"Invalid file, failed to find SBU caption file: " + caption_path_.ToString());
MS_LOG(INFO) << "SBU operator found caption file " << caption_path_.ToString() << ".";
image_folder_ = root_dir / image_folder_name;
CHECK_FAIL_RETURN_UNEXPECTED(image_folder_.Exists() && image_folder_.IsDirectory(),
"Invalid folder, failed to find SBU image folder: " + image_folder_.ToString());
MS_LOG(INFO) << "SBU operator found image folder " << image_folder_.ToString() << ".";
std::ifstream url_file_reader;
std::ifstream caption_file_reader;
url_file_reader.open(url_path_.ToString(), std::ios::in);
caption_file_reader.open(caption_path_.ToString(), std::ios::in);
CHECK_FAIL_RETURN_UNEXPECTED(url_file_reader.is_open(),
"Invalid file, failed to open SBU url file: " + url_path_.ToString());
CHECK_FAIL_RETURN_UNEXPECTED(caption_file_reader.is_open(),
"Invalid file, failed to open SBU caption file: " + caption_path_.ToString());
Status rc = GetAvailablePairs(url_file_reader, caption_file_reader);
url_file_reader.close();
caption_file_reader.close();
return rc;
}
Status SBUOp::GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader) {
std::string url_line;
std::string caption_line;
int64_t line_num = 0;
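// The url and caption files must stay line-aligned: a line may be blank only if the
// corresponding line in the other file is blank too, and such paired blank lines are skipped.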
while (std::getline(url_file_reader, url_line) && std::getline(caption_file_reader, caption_line)) {
CHECK_FAIL_RETURN_UNEXPECTED(
(url_line.empty() && caption_line.empty()) || (!url_line.empty() && !caption_line.empty()),
"Invalid data, SBU url and caption file are mismatched: " + url_path_.ToString() + " and " +
caption_path_.ToString());
if (!url_line.empty() && !caption_line.empty()) {
line_num++;
RETURN_IF_NOT_OK(this->ParsePair(url_line, caption_line));
}
}
image_caption_pairs_.shrink_to_fit();
CHECK_FAIL_RETURN_UNEXPECTED(!image_caption_pairs_.empty(), "No valid images in " + image_folder_.ToString());
// base field of RandomAccessOp
num_rows_ = image_caption_pairs_.size();
return Status::OK();
}
Status SBUOp::ParsePair(const std::string &url, const std::string &caption) {
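// Official SBU urls look like "http://static.flickr.com/<dir>/<name>.jpg"; skipping the
// first 23 characters keeps "m/<dir>/<name>.jpg", and replacing '/' with '_' below yields the
// local file names produced by the official download.m (e.g. "m_3326_3596303505_3ce4c20529.jpg").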
std::string image_name = url.substr(23, std::string::npos);
RETURN_IF_NOT_OK(this->ReplaceAll(&image_name, "/", "_"));
Path image_path = image_folder_ / Path(image_name);
if (image_path.Exists() && !image_path.IsDirectory()) {
// rstrip caption
image_caption_pairs_.emplace_back(std::make_pair(image_path, caption.substr(0, caption.find_last_not_of(" ") + 1)));
}
return Status::OK();
}
Status SBUOp::ReplaceAll(std::string *str, const std::string &from, const std::string &to) {
size_t pos = 0;
while ((pos = str->find(from, pos)) != std::string::npos) {
str->replace(pos, from.length(), to);
pos += to.length();
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
using SBUImageCaptionPair = std::pair<Path, std::string>;
class SBUOp : public MappableLeafOp {
public:
// Constructor.
// @param const std::string &folder_path - directory of the SBU dataset.
// @param bool decode - whether to decode images.
// @param std::unique_ptr<DataSchema> data_schema - the schema of the SBU dataset.
// @param std::shared_ptr<SamplerRT> sampler - sampler tells SBUOp what to read.
// @param int32_t num_workers - number of workers reading images in parallel.
// @param int32_t queue_size - connector queue size.
SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size);
// Destructor.
~SBUOp() = default;
// Op name getter.
// @return std::string - Name of the current Op.
std::string Name() const override { return "SBUOp"; }
// A print method typically used for debugging.
// @param std::ostream &out - out stream.
// @param bool show_all - whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the SBU dataset.
// @param const std::string &dir - path to the SBU directory.
// @param int64_t *count - output arg that will hold the number of valid image-caption pairs found under dir.
// @return Status - The status code returned.
static Status CountTotalRows(const std::string &dir, int64_t *count);
private:
// Load a tensor row according to a pair.
// @param row_id_type row_id - id for this tensor row.
// @param TensorRow row - image & caption read into this tensor row.
// @return Status - The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
// Private function for computing the assignment of the column name map.
// @return Status - The status code returned.
Status ComputeColMap() override;
// Launch worker threads and parse the SBU data; called once before rows are fetched.
// @return Status - The status code returned.
Status LaunchThreadsAndInitOp() override;
// Read an image file into a tensor, decoding it when decode_ is set.
// @param const std::string &path - path to the image file.
// @param std::shared_ptr<Tensor> tensor - tensor to store image.
// @return Status - The status code returned.
Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor);
// Parse SBU data file.
// @return Status - The status code returned.
Status ParseSBUData();
// Get available image-caption pairs.
// @param std::ifstream &url_file_reader - url file reader.
// @param std::ifstream &caption_file_reader - caption file reader.
// @return Status - The status code returned.
Status GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader);
// Parse path-caption pair.
// @param const std::string &url - image url.
// @param const std::string &caption - caption.
// @return Status - The status code returned.
Status ParsePair(const std::string &url, const std::string &caption);
// A utility for in-place substring replacement.
// @param std::string *str - string to be modified.
// @param const std::string &from - substring to replace.
// @param const std::string &to - replacement substring.
// @return Status - The status code returned.
Status ReplaceAll(std::string *str, const std::string &from, const std::string &to);
std::string folder_path_; // directory of data files
const bool decode_;
std::unique_ptr<DataSchema> data_schema_;
Path url_path_;
Path caption_path_;
Path image_folder_;
std::vector<SBUImageCaptionPair> image_caption_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

View File

@@ -90,6 +90,7 @@ constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kSBUNode[] = "SBUDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUSPSNode[] = "USPSDataset";

View File

@@ -18,6 +18,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
minddata_node.cc
mnist_node.cc
random_node.cc
sbu_node.cc
text_file_node.cc
tf_record_node.cc
usps_node.cc

View File

@@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/sbu_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
SBUNode::SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache)
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}
std::shared_ptr<DatasetNode> SBUNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<SBUNode>(dataset_dir_, decode_, sampler, cache_);
return node;
}
void SBUNode::Print(std::ostream &out) const {
out << (Name() + "(dataset dir: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") +
", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
}
Status SBUNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUNode", sampler_));
Path root_dir(dataset_dir_);
Path url_path = root_dir / Path("SBU_captioned_photo_dataset_urls.txt");
Path caption_path = root_dir / Path("SBU_captioned_photo_dataset_captions.txt");
Path image_path = root_dir / Path("sbu_images");
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {url_path.ToString()}));
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {caption_path.ToString()}));
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", {image_path.ToString()}));
return Status::OK();
}
Status SBUNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<SBUOp>(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_,
connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node
Status SBUNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size
Status SBUNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
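// The dataset size is the number of valid pairs on disk clipped by the sampler; when the
// sampler cannot compute its sample count statically (-1), fall back to a dry run.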
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(SBUOp::CountTotalRows(dataset_dir_, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status SBUNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["decode"] = decode_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class SBUNode : public MappableSourceNode {
public:
/// \brief Constructor.
SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~SBUNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kSBUNode; }
/// \brief Print the description.
/// \param out - The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base-class override function to create the required runtime dataset op objects for this class.
/// \param node_ops - A vector containing shared pointers to the DatasetOps that this object will create.
/// \return Status Status::OK() if build succeeded.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id of the node.
/// \return Status Status::OK() if the shard id is obtained successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
bool Decode() const { return decode_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_

View File

@@ -2332,6 +2332,78 @@ std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, con
return ds;
}
/// \class SBUDataset
/// \brief A source dataset that reads and parses the SBU dataset.
class SBUDataset : public Dataset {
public:
/// \brief Constructor of SBUDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
/// \param[in] cache Tensor cache to use.
explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of SBUDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of SBUDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache);
/// Destructor of SBUDataset.
~SBUDataset() = default;
};
/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
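/// \par Example
/// A minimal usage sketch (the dataset path is hypothetical), mirroring the pipeline unit
/// tests added in this commit:
/// \code
/// std::shared_ptr<Dataset> ds = SBU("/path/to/sbu_dataset_directory", true);
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);  // row["image"]: uint8 tensor; row["caption"]: string scalar
/// iter->Stop();
/// \endcode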
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode = false,
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
/// \class TextFileDataset
/// \brief A source dataset that reads and parses datasets stored on disk in text format.
class TextFileDataset : public Dataset {

View File

@@ -45,6 +45,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class MindDataDataset;
friend class MnistDataset;
friend class RandomDataDataset;
friend class SBUDataset;
friend class TextFileDataset;
friend class TFRecordDataset;
friend class USPSDataset;

View File

@@ -64,7 +64,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -5612,6 +5613,120 @@ class CSVDataset(SourceDataset):
self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
class SBUDataset(MappableDataset):
"""
A source dataset for reading and parsing the SBU dataset.
The generated dataset has two columns :py:obj:`[image, caption]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`caption` is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
decode (bool, optional): Decode the images after reading (default=False).
num_samples (int, optional): The number of images to be included in the dataset
(default=None, will read all images).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, will use value set in the config).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
argument can only be specified when `num_shards` is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> sbu_dataset_dir = "/path/to/sbu_dataset_directory"
>>> # Read 3 samples from SBU dataset
>>> dataset = ds.SBUDataset(dataset_dir=sbu_dataset_dir, num_samples=3)
About SBU dataset:
SBU dataset is a large captioned photo collection.
It contains one million images with associated visually relevant captions.
You should manually download the images using the official download.m script (replacing
'urls{i}(24, end)' with 'urls{i}(24:1:end)') and keep the directory layout shown below.
.. code-block::
.
dataset_dir
SBU_captioned_photo_dataset_captions.txt
SBU_captioned_photo_dataset_urls.txt
sbu_images
m_3326_3596303505_3ce4c20529.jpg
......
m_2522_4182181099_c3c23ab1cc.jpg
Citation:
.. code-block::
@inproceedings{Ordonez:2011:im2text,
Author = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg},
Title = {Im2Text: Describing Images Using 1 Million Captioned Photographs},
Booktitle = {Neural Information Processing Systems ({NIPS})},
Year = {2011},
}
"""
@check_sbu_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False,
sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.decode = replace_none(decode, False)
def parse(self, children=None):
return cde.SBUNode(self.dataset_dir, self.decode, self.sampler)
class _Flowers102Dataset:
"""
Mainly for loading Flowers102 Dataset, and return one row each time.

View File

@@ -122,6 +122,36 @@ def check_manifestdataset(method):
return new_method
def check_sbu_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
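# The SBU layout is validated eagerly: the url/caption text files and the sbu_images
# folder must already exist under dataset_dir before the pipeline is built.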
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
check_dir(os.path.join(dataset_dir, "sbu_images"))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_tfrecorddataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

View File

@@ -30,6 +30,7 @@ SET(DE_UT_SRCS
c_api_dataset_ops_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_save.cc
c_api_dataset_sbu_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_usps_test.cc

View File

@@ -0,0 +1,188 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestSBUDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDataset.";
// Create an SBU Dataset
std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
std::shared_ptr<Dataset> ds = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("caption"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSBUDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithPipeline.";
// Create two SBU Datasets
std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
std::shared_ptr<Dataset> ds1 = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
std::shared_ptr<Dataset> ds2 = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operations on ds
int32_t repeat_num = 1;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 1;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operations on ds
std::vector<std::string> column_project = {"image", "caption"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the two datasets
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("caption"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestGetSBUDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSBUDatasetSize.";
// Create an SBU Dataset
std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
std::shared_ptr<Dataset> ds = SBU(folder_path, true);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 5);
}
TEST_F(MindDataTestPipeline, TestSBUDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetGetters.";
// Create an SBU Dataset
std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
std::shared_ptr<Dataset> ds = SBU(folder_path, true);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 5);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "caption"};
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "string");
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 5);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 5);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 5);
}
TEST_F(MindDataTestPipeline, TestSBUDatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetFail.";
// Create an SBU Dataset
std::shared_ptr<Dataset> ds = SBU("", true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid SBU input
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestSBUDatasetWithNullSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithNullSamplerFail.";
// Create an SBU Dataset
std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
std::shared_ptr<Dataset> ds = SBU(folder_path, true, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid SBU input, sampler cannot be nullptr
EXPECT_EQ(iter, nullptr);
}

View File

@@ -0,0 +1,10 @@
This is 0
This is 1.
This is 2
This is 3.
This is 4
This is 5.
This is 6
This is 7.
This is 8
This is 9.

View File

@@ -0,0 +1,10 @@
http://123456.123456.123/123/123456789_123456780.jpg
http://123456.123456.123/123/123456789_123456781.jpg
http://123456.123456.123/123/123456789_123456782.jpg
http://123456.123456.123/123/123456789_123456783.jpg
http://123456.123456.123/123/123456789_123456784.jpg
http://123456.123456.123/123/123456789_123456785.jpg
http://123456.123456.123/123/123456789_123456786.jpg
http://123456.123456.123/123/123456789_123456787.jpg
http://123456.123456.123/123/123456789_123456788.jpg
http://123456.123456.123/123/123456789_123456789.jpg

Binary file not shown. (new image, 155 KiB)

Binary file not shown. (new image, 172 KiB)

Binary file not shown. (new image, 207 KiB)

Binary file not shown. (new image, 63 KiB)

Binary file not shown. (new image, 53 KiB)

View File

@@ -0,0 +1,303 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SBU dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testSBUDataset"
WRONG_DIR = "../data/dataset/testMnistData"
def load_sbu(path):
"""
load SBU data
"""
images = []
captions = []
file1 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt'))
file2 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt'))
for line1, line2 in zip(open(file1), open(file2)):
url = line1.rstrip()
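# mirrors SBUOp::ParsePair: the local file name is url[23:] with '/' replaced by '_'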
image = url[23:].replace("/", "_")
filename = os.path.join(path, 'sbu_images', image)
if os.path.exists(filename):
caption = line2.rstrip()
images.append(np.asarray(Image.open(filename).convert('RGB')).astype(np.uint8))
captions.append(caption)
return images, captions
def visualize_dataset(images, captions):
"""
Helper function to visualize the dataset samples
"""
num_samples = len(images)
for i in range(num_samples):
plt.subplot(1, num_samples, i + 1)
plt.imshow(images[i].squeeze())
plt.title(captions[i])
plt.show()
def test_sbu_content_check():
"""
Validate SBUDataset image readings
"""
logger.info("Test SBUDataset Op with content check")
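# the test data contains only 5 valid image-caption pairs, so num_samples=50 is capped at 5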
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False)
images, captions = load_sbu(DATA_DIR)
num_iter = 0
# in this example, each dictionary has keys "image" and "caption"
for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
assert data["image"].shape == images[i].shape
assert data["caption"].item().decode("utf8") == captions[i]
num_iter += 1
assert num_iter == 5
def test_sbu_case():
"""
Validate SBUDataset cases
"""
dataset = ds.SBUDataset(DATA_DIR, decode=True)
dataset = dataset.map(operations=[vision.Resize((224, 224))], input_columns=["image"])
repeat_num = 4
dataset = dataset.repeat(repeat_num)
batch_size = 2
dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})
num = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num += 1
# 4 x 5 / 2
assert num == 10
dataset = ds.SBUDataset(DATA_DIR, decode=False)
dataset = dataset.map(operations=[vision.Decode(rgb=True), vision.Resize((224, 224))], input_columns=["image"])
repeat_num = 4
dataset = dataset.repeat(repeat_num)
batch_size = 2
dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})
num = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num += 1
# 4 x 5 / 2
assert num == 10
def test_sbu_basic():
"""
Validate SBUDataset
"""
logger.info("Test SBUDataset Op")
# case 1: test loading whole dataset
dataset = ds.SBUDataset(DATA_DIR, decode=True)
num_iter = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 5
# case 2: test num_samples
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
num_iter = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 5
# case 3: test repeat
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
dataset = dataset.repeat(5)
num_iter = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 25
# case 4: test batch
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
assert dataset.get_dataset_size() == 5
assert dataset.get_batch_size() == 1
num_iter = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 5
# case 5: test get_class_indexing
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
assert dataset.get_class_indexing() == {}
# case 6: test get_col_names
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
assert dataset.get_col_names() == ["image", "caption"]
def test_sbu_sequential_sampler():
"""
Test SBUDataset with SequentialSampler
"""
logger.info("Test SBUDataset Op with SequentialSampler")
num_samples = 5
sampler = ds.SequentialSampler(num_samples=num_samples)
dataset_1 = ds.SBUDataset(DATA_DIR, decode=True, sampler=sampler)
dataset_2 = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=num_samples)
num_iter = 0
for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["caption"], item2["caption"])
num_iter += 1
assert num_iter == num_samples
def test_sbu_exception():
"""
Test error cases for SBUDataset
"""
logger.info("Test error cases for SBUDataset")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler())
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.SBUDataset(DATA_DIR, decode=True, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.SBUDataset(DATA_DIR, decode=True, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")
def exception_func(item):
raise Exception("Error occurred!")
error_msg_8 = "The corresponding data files"
with pytest.raises(RuntimeError, match=error_msg_8):
dataset = ds.SBUDataset(DATA_DIR, decode=True)
dataset = dataset.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in dataset.__iter__():
pass
with pytest.raises(RuntimeError, match=error_msg_8):
dataset = ds.SBUDataset(DATA_DIR, decode=True)
dataset = dataset.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
for _ in dataset.__iter__():
pass
error_msg_9 = "does not exist or permission denied"
with pytest.raises(ValueError, match=error_msg_9):
dataset = ds.SBUDataset(WRONG_DIR, decode=True)
for _ in dataset.__iter__():
pass
error_msg_10 = "Argument decode with value"
with pytest.raises(TypeError, match=error_msg_10):
dataset = ds.SBUDataset(DATA_DIR, decode="not_bool")
for _ in dataset.__iter__():
pass
def test_sbu_visualize(plot=False):
"""
Visualize SBUDataset results
"""
logger.info("Test SBUDataset visualization")
dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False)
num_iter = 0
image_list, caption_list = [], []
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
caption = item["caption"].item().decode("utf8")
image_list.append(image)
caption_list.append("caption {}".format(caption))
assert isinstance(image, np.ndarray)
assert image.dtype == np.uint8
assert isinstance(caption, str)
num_iter += 1
assert num_iter == 5
if plot:
visualize_dataset(image_list, caption_list)
def test_sbu_decode():
"""
Validate the SBUDataset decode flag
"""
logger.info("Test SBUDataset decode flag")
sampler = ds.SequentialSampler(num_samples=50)
dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler)
dataset_1 = dataset.map(operations=[vision.Decode(rgb=True)], input_columns=["image"])
dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler)
num_iter = 0
for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["caption"], item2["caption"])
num_iter += 1
assert num_iter == 5
if __name__ == '__main__':
test_sbu_content_check()
test_sbu_basic()
test_sbu_case()
test_sbu_sequential_sampler()
test_sbu_exception()
test_sbu_visualize(plot=True)
test_sbu_decode()