forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6VH] add new data operator SBU
parent cc7a2b74ac
commit af97ba4a77

@@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|

@@ -1277,6 +1278,27 @@ RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vecto
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
                       const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler ? sampler->Parse() : nullptr;
  auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
                       const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler ? sampler->Parse() : nullptr;
  auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
                       const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler.get().Parse();
  auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
                                 const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
                                 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,

@@ -44,6 +44,7 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"

@@ -264,6 +265,16 @@ PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
                  }));
                }));

PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
                  (void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
                                                                                   "to create an SBUNode")
                    .def(py::init([](std::string dataset_dir, bool decode, py::handle sampler) {
                      auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
                      THROW_IF_ERROR(sbu->ValidateParams());
                      return sbu;
                    }));
                }));

PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
                  (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
                                                                                             "to create a TextFileNode")

@@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
    cifar_op.cc
    random_data_op.cc
    celeba_op.cc
    sbu_op.cc
    text_file_op.cc
    clue_op.cc
    csv_op.cc

@@ -0,0 +1,234 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/sbu_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>

#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace dataset {
SBUOp::SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
             std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size)
    : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
      folder_path_(folder_path),
      decode_(decode),
      url_path_(""),
      caption_path_(""),
      image_folder_(""),
      data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers, queue_size);
}

void SBUOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nNumber of rows: " << num_rows_ << "\nSBU directory: " << folder_path_
        << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
  }
}

// Load 1 TensorRow (image, caption) using 1 SBUImageCaptionPair.
Status SBUOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);

  SBUImageCaptionPair image_caption_pair = image_caption_pairs_[row_id];
  Path path = image_caption_pair.first;

  std::shared_ptr<Tensor> image, caption;
  RETURN_IF_NOT_OK(ReadImageToTensor(path.ToString(), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(image_caption_pair.second, &caption));

  (*trow) = TensorRow(row_id, {std::move(image), std::move(caption)});
  trow->setPath({path.ToString()});
  return Status::OK();
}

Status SBUOp::ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) {
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
  if (decode_) {
    Status rc = Decode(*tensor, tensor);
    if (rc.IsError()) {
      RETURN_STATUS_UNEXPECTED("Invalid data, failed to decode image: " + path);
    }
  }
  return Status::OK();
}

Status SBUOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

Status SBUOp::CountTotalRows(const std::string &dir, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));

  const int64_t num_samples = 0;
  const int64_t start_index = 0;
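  // Assumption from surrounding usage: a SequentialSamplerRT constructed with start_index 0
  // and num_samples 0 places no limit on the sample count, so every row is visited for counting.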
  auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();

  // decode does not affect the count result, so default it to true.
  auto op = std::make_shared<SBUOp>(dir, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);

  // the logic of counting the number of samples
  RETURN_IF_NOT_OK(op->ParseSBUData());
  *count = op->image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&SBUOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();

  RETURN_IF_NOT_OK(this->ParseSBUData());
  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with the sampler
  return Status::OK();
}

Status SBUOp::ParseSBUData() {
  const Path url_file_name("SBU_captioned_photo_dataset_urls.txt");
  const Path caption_file_name("SBU_captioned_photo_dataset_captions.txt");
  const Path image_folder_name("sbu_images");
  auto real_folder_path = Common::GetRealPath(folder_path_);
  CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
  Path root_dir(real_folder_path.value());

  url_path_ = root_dir / url_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(url_path_.Exists() && !url_path_.IsDirectory(),
                               "Invalid file, failed to find SBU url file: " + url_path_.ToString());
  MS_LOG(INFO) << "SBU operator found url file " << url_path_.ToString() << ".";

  caption_path_ = root_dir / caption_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(caption_path_.Exists() && !caption_path_.IsDirectory(),
                               "Invalid file, failed to find SBU caption file: " + caption_path_.ToString());
  MS_LOG(INFO) << "SBU operator found caption file " << caption_path_.ToString() << ".";

  image_folder_ = root_dir / image_folder_name;
  CHECK_FAIL_RETURN_UNEXPECTED(image_folder_.Exists() && image_folder_.IsDirectory(),
                               "Invalid folder, failed to find SBU image folder: " + image_folder_.ToString());
  MS_LOG(INFO) << "SBU operator found image folder " << image_folder_.ToString() << ".";

  std::ifstream url_file_reader;
  std::ifstream caption_file_reader;

  url_file_reader.open(url_path_.ToString(), std::ios::in);
  caption_file_reader.open(caption_path_.ToString(), std::ios::in);

  CHECK_FAIL_RETURN_UNEXPECTED(url_file_reader.is_open(),
                               "Invalid file, failed to open SBU url file: " + url_path_.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(caption_file_reader.is_open(),
                               "Invalid file, failed to open SBU caption file: " + caption_path_.ToString());

  Status rc = GetAvailablePairs(url_file_reader, caption_file_reader);
  url_file_reader.close();
  caption_file_reader.close();
  if (rc.IsError()) {
    return rc;
  }

  return Status::OK();
}

Status SBUOp::GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader) {
  std::string url_line;
  std::string caption_line;
  int64_t line_num = 0;
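  // The two files are line-aligned: line i of the url file and line i of the caption file
  // describe the same sample, so a blank line in one must be blank in the other as well.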

  while (std::getline(url_file_reader, url_line) && std::getline(caption_file_reader, caption_line)) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      (url_line.empty() && caption_line.empty()) || (!url_line.empty() && !caption_line.empty()),
      "Invalid data, SBU url and caption file are mismatched: " + url_path_.ToString() + " and " +
        caption_path_.ToString());
    if (!url_line.empty() && !caption_line.empty()) {
      line_num++;
      RETURN_IF_NOT_OK(this->ParsePair(url_line, caption_line));
    }
  }

  image_caption_pairs_.shrink_to_fit();

  CHECK_FAIL_RETURN_UNEXPECTED(image_caption_pairs_.size() > 0, "No valid images in " + image_folder_.ToString());

  // base field of RandomAccessOp
  num_rows_ = image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::ParsePair(const std::string &url, const std::string &caption) {
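  // Keep everything from character 24 of the url onward (substr(23) in 0-based indexing),
  // matching the modified download.m expression 'urls{i}(24:1:end)'; the remaining '/' become
  // '_' so the name matches the flat files under sbu_images (e.g. m_3326_3596303505_3ce4c20529.jpg).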
  std::string image_name = url.substr(23, std::string::npos);
  RETURN_IF_NOT_OK(this->ReplaceAll(&image_name, "/", "_"));

  Path image_path = image_folder_ / Path(image_name);
  if (image_path.Exists() && !image_path.IsDirectory()) {
    // rstrip caption
    image_caption_pairs_.emplace_back(std::make_pair(image_path, caption.substr(0, caption.find_last_not_of(" ") + 1)));
  }

  return Status::OK();
}

Status SBUOp::ReplaceAll(std::string *str, const std::string &from, const std::string &to) {
  size_t pos = 0;
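  // Advance pos past each replacement so the scan stays finite even if `to` contains `from`.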
  while ((pos = str->find(from, pos)) != std::string::npos) {
    str->replace(pos, from.length(), to);
    pos += to.length();
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore

@@ -0,0 +1,125 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

using SBUImageCaptionPair = std::pair<Path, std::string>;

class SBUOp : public MappableLeafOp {
 public:
  // Constructor.
  // @param const std::string &folder_path - directory of the SBU dataset.
  // @param bool decode - whether to decode images.
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the SBU dataset.
  // @param std::shared_ptr<SamplerRT> sampler - sampler that tells SBUOp what to read.
  // @param int32_t num_workers - number of workers reading images in parallel.
  // @param int32_t queue_size - connector queue size.
  SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
        std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size);

  // Destructor.
  ~SBUOp() = default;

  // Op name getter.
  // @return std::string - Name of the current Op.
  std::string Name() const override { return "SBUOp"; }

  // A print method typically used for debugging.
  // @param std::ostream &out - out stream.
  // @param bool show_all - whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  // Function to count the number of samples in the SBU dataset.
  // @param const std::string &dir - path to the SBU directory.
  // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples.
  // @return Status - The status code returned.
  static Status CountTotalRows(const std::string &dir, int64_t *count);

 private:
  // Load a tensor row according to a pair.
  // @param row_id_type row_id - id for this tensor row.
  // @param TensorRow row - image & caption read into this tensor row.
  // @return Status - The status code returned.
  Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;

  // Private function for computing the assignment of the column name map.
  // @return Status - The status code returned.
  Status ComputeColMap() override;

  // Called first when the op is launched; registers queues, launches workers and parses the dataset.
  // @return Status - The status code returned.
  Status LaunchThreadsAndInitOp() override;

  // Read an image file into a tensor.
  // @param const std::string &path - path to the image file.
  // @param std::shared_ptr<Tensor> tensor - tensor to store the image.
  // @return Status - The status code returned.
  Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor);

  // Parse SBU data file.
  // @return Status - The status code returned.
  Status ParseSBUData();

  // Get available image-caption pairs.
  // @param std::ifstream &url_file_reader - url file reader.
  // @param std::ifstream &caption_file_reader - caption file reader.
  // @return Status - The status code returned.
  Status GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader);

  // Parse path-caption pair.
  // @param const std::string &url - image url.
  // @param const std::string &caption - caption.
  // @return Status - The status code returned.
  Status ParsePair(const std::string &url, const std::string &caption);

  // A util for string replacement.
  // @param std::string *str - string to be modified in place.
  // @param const std::string &from - substring to replace.
  // @param const std::string &to - replacement substring.
  // @return Status - The status code returned.
  Status ReplaceAll(std::string *str, const std::string &from, const std::string &to);

  std::string folder_path_;  // directory of data files
  const bool decode_;
  std::unique_ptr<DataSchema> data_schema_;

  Path url_path_;
  Path caption_path_;
  Path image_folder_;
  std::vector<SBUImageCaptionPair> image_caption_pairs_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

@@ -90,6 +90,7 @@ constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kSBUNode[] = "SBUDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUSPSNode[] = "USPSDataset";

@@ -18,6 +18,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
    minddata_node.cc
    mnist_node.cc
    random_node.cc
    sbu_node.cc
    text_file_node.cc
    tf_record_node.cc
    usps_node.cc

@@ -0,0 +1,123 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/sbu_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

SBUNode::SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
                 const std::shared_ptr<DatasetCache> &cache)
    : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}

std::shared_ptr<DatasetNode> SBUNode::Copy() {
  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
  auto node = std::make_shared<SBUNode>(dataset_dir_, decode_, sampler, cache_);
  return node;
}

void SBUNode::Print(std::ostream &out) const {
  out << (Name() + "(dataset dir: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") +
          ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
}

Status SBUNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUNode", sampler_));

  Path root_dir(dataset_dir_);

  Path url_path = root_dir / Path("SBU_captioned_photo_dataset_urls.txt");
  Path caption_path = root_dir / Path("SBU_captioned_photo_dataset_captions.txt");
  Path image_path = root_dir / Path("sbu_images");

  RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {url_path.ToString()}));
  RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {caption_path.ToString()}));
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", {image_path.ToString()}));

  return Status::OK();
}

Status SBUNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));

  auto op = std::make_shared<SBUOp>(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_,
                                    connector_que_size_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);

  return Status::OK();
}

// Get the shard id of node
Status SBUNode::GetShardId(int32_t *shard_id) {
  *shard_id = sampler_->ShardId();

  return Status::OK();
}

// Get Dataset size
Status SBUNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                               int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(SBUOp::CountTotalRows(dataset_dir_, &num_rows));
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  sample_size = sampler_rt->CalculateNumSamples(num_rows);
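  // CalculateNumSamples returns -1 when the count cannot be derived from num_rows alone;
  // in that case fall back to a dry run of the pipeline to count the rows it would produce.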
  if (sample_size == -1) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
  }
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

Status SBUNode::to_json(nlohmann::json *out_json) {
  nlohmann::json args, sampler_args;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
  args["sampler"] = sampler_args;
  args["num_parallel_workers"] = num_workers_;
  args["dataset_dir"] = dataset_dir_;
  args["decode"] = decode_;
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore

@@ -0,0 +1,97 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

class SBUNode : public MappableSourceNode {
 public:
  /// \brief Constructor.
  SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr<SamplerObj> &sampler,
          const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~SBUNode() = default;

  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kSBUNode; }

  /// \brief Print the description.
  /// \param out - The output stream to write output to.
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief A base class override function to create the required runtime dataset op objects for this class.
  /// \param node_ops - A vector containing shared pointers to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build is successful.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;

  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;

  /// \brief Get the shard id of node.
  /// \param[out] shard_id The shard id of the node.
  /// \return Status Status::OK() if the shard id is obtained successfully.
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size The size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }
  bool Decode() const { return decode_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

  /// \brief Sampler getter.
  /// \return SamplerObj of the current node.
  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

  /// \brief Sampler setter.
  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

 private:
  std::string dataset_dir_;
  bool decode_;
  std::shared_ptr<SamplerObj> sampler_;
};

}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_

@@ -2332,6 +2332,78 @@ std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, con
  return ds;
}

/// \class SBUDataset
/// \brief A source dataset that reads and parses the SBU dataset.
class SBUDataset : public Dataset {
 public:
  /// \brief Constructor of SBUDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  ///     given, a `RandomSampler` will be used to randomly iterate the entire dataset.
  /// \param[in] cache Tensor cache to use.
  explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
                      const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of SBUDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
                      const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of SBUDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  explicit SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
                      const std::shared_ptr<DatasetCache> &cache);

  /// Destructor of SBUDataset.
  ~SBUDataset() = default;
};

/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode = false,
                                       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}

/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode, const Sampler *sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}

/// \brief Function to create an SBUDataset.
/// \notes The generated dataset has two columns ["image", "caption"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the current SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode,
                                       const std::reference_wrapper<Sampler> sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SBUDataset>(StringToChar(dataset_dir), decode, sampler, cache);
}
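
// Illustrative usage (a sketch mirroring the pipeline exercised by the unit tests added
// below; "/path/to/sbu" is a placeholder directory):
//   std::shared_ptr<Dataset> ds = SBU("/path/to/sbu", true, std::make_shared<RandomSampler>(false, 5));
//   std::shared_ptr<Iterator> iter = ds->CreateIterator();
//   std::unordered_map<std::string, mindspore::MSTensor> row;
//   iter->GetNextRow(&row);  // columns: row["image"], row["caption"]
//   iter->Stop();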

/// \class TextFileDataset
/// \brief A source dataset that reads and parses datasets stored on disk in text format.
class TextFileDataset : public Dataset {

@@ -45,6 +45,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class RandomDataDataset;
  friend class SBUDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;

@@ -64,7 +64,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
     check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
     check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
-    check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset
+    check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
+    check_sbu_dataset
 from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
     get_prefetch_size
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

@@ -5612,6 +5613,120 @@ class CSVDataset(SourceDataset):
        self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)


class SBUDataset(MappableDataset):
    """
    A source dataset for reading and parsing the SBU dataset.

    The generated dataset has two columns :py:obj:`[image, caption]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`caption` is of the string type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        decode (bool, optional): Decode the images after reading (default=False).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, will read all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, will use value set in the config).
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
            (default=None, expected order behavior shown in the table).
        sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If dataset_dir does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Note:
        - This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter 'sampler'
         - Parameter 'shuffle'
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> sbu_dataset_dir = "/path/to/sbu_dataset_directory"
        >>> # Read 3 samples from SBU dataset
        >>> dataset = ds.SBUDataset(dataset_dir=sbu_dataset_dir, num_samples=3)

    About SBU dataset:

    SBU dataset is a large captioned photo collection.
    It contains one million images with associated visually relevant captions.

    You should manually download the images using the official download.m by replacing 'urls{i}(24, end)' with
    'urls{i}(24:1:end)', and keep the directory structure as below.

    .. code-block::

        .
        └─ dataset_dir
           ├── SBU_captioned_photo_dataset_captions.txt
           ├── SBU_captioned_photo_dataset_urls.txt
           └── sbu_images
               ├── m_3326_3596303505_3ce4c20529.jpg
               ├── ......
               └── m_2522_4182181099_c3c23ab1cc.jpg

    Citation:

    .. code-block::

        @inproceedings{Ordonez:2011:im2text,
          Author    = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg},
          Title     = {Im2Text: Describing Images Using 1 Million Captioned Photographs},
          Booktitle = {Neural Information Processing Systems ({NIPS})},
          Year      = {2011},
        }
    """

    @check_sbu_dataset
    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False,
                 sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir
        self.decode = replace_none(decode, False)

    def parse(self, children=None):
        return cde.SBUNode(self.dataset_dir, self.decode, self.sampler)


class _Flowers102Dataset:
    """
    Mainly for loading the Flowers102 dataset, returning one row each time.

@@ -122,6 +122,36 @@ def check_manifestdataset(method):
    return new_method


def check_sbu_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
        check_dir(os.path.join(dataset_dir, "sbu_images"))
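        # Validating the expected SBU layout here makes a malformed directory fail at
        # construction time rather than deep inside the pipeline during iteration.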

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_tfrecorddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

@@ -30,6 +30,7 @@ SET(DE_UT_SRCS
    c_api_dataset_ops_test.cc
    c_api_dataset_randomdata_test.cc
    c_api_dataset_save.cc
    c_api_dataset_sbu_test.cc
    c_api_dataset_textfile_test.cc
    c_api_dataset_tfrecord_test.cc
    c_api_dataset_usps_test.cc

@@ -0,0 +1,188 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"

#include "minddata/dataset/include/dataset/datasets.h"

using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;

class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};

TEST_F(MindDataTestPipeline, TestSBUDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDataset.";

  // Create an SBU Dataset
  std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
  std::shared_ptr<Dataset> ds = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("caption"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();
}

TEST_F(MindDataTestPipeline, TestSBUDatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithPipeline.";

  // Create two SBU Datasets
  std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
  std::shared_ptr<Dataset> ds1 = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
  std::shared_ptr<Dataset> ds2 = SBU(folder_path, true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create two Repeat operations on ds
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operations on ds
  std::vector<std::string> column_project = {"image", "caption"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Create a Concat operation on the two datasets
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("caption"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
}

TEST_F(MindDataTestPipeline, TestGetSBUDatasetSize) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSBUDatasetSize.";

  // Create an SBU Dataset
  std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
  std::shared_ptr<Dataset> ds = SBU(folder_path, true);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 5);
}

TEST_F(MindDataTestPipeline, TestSBUDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetGetters.";

  // Create an SBU Dataset
  std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
  std::shared_ptr<Dataset> ds = SBU(folder_path, true);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 5);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "caption"};
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "string");

  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  EXPECT_EQ(ds->GetDatasetSize(), 5);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), -1);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 5);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), -1);
  EXPECT_EQ(ds->GetDatasetSize(), 5);
}

TEST_F(MindDataTestPipeline, TestSBUDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetFail.";

  // Create an SBU Dataset with an empty directory
  std::shared_ptr<Dataset> ds = SBU("", true, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid SBU input
  EXPECT_EQ(iter, nullptr);
}

TEST_F(MindDataTestPipeline, TestSBUDatasetWithNullSamplerFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithNullSamplerFail.";

  // Create an SBU Dataset
  std::string folder_path = datasets_root_path_ + "/testSBUDataset/";
  std::shared_ptr<Dataset> ds = SBU(folder_path, true, nullptr);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid SBU input, sampler cannot be nullptr
  EXPECT_EQ(iter, nullptr);
}

@@ -0,0 +1,10 @@
This is 0
This is 1.
This is 2
This is 3.
This is 4
This is 5.
This is 6
This is 7.
This is 8
This is 9.

@@ -0,0 +1,10 @@
http://123456.123456.123/123/123456789_123456780.jpg
http://123456.123456.123/123/123456789_123456781.jpg
http://123456.123456.123/123/123456789_123456782.jpg
http://123456.123456.123/123/123456789_123456783.jpg
http://123456.123456.123/123/123456789_123456784.jpg
http://123456.123456.123/123/123456789_123456785.jpg
http://123456.123456.123/123/123456789_123456786.jpg
http://123456.123456.123/123/123456789_123456787.jpg
http://123456.123456.123/123/123456789_123456788.jpg
http://123456.123456.123/123/123456789_123456789.jpg

(Five binary image files added under sbu_images, not shown: 155 KiB, 172 KiB, 207 KiB, 63 KiB, 53 KiB.)

@@ -0,0 +1,303 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SBU dataset operators
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testSBUDataset"
WRONG_DIR = "../data/dataset/testMnistData"


def load_sbu(path):
    """
    load SBU data
    """
    images = []
    captions = []

    file1 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt'))
    file2 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt'))

    for line1, line2 in zip(open(file1), open(file2)):
        url = line1.rstrip()
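        # Mirrors SBUOp::ParsePair on the C++ side: keep the url from character 24 onward
        # and flatten '/' into '_' to recover the local file name under sbu_images.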
        image = url[23:].replace("/", "_")
        filename = os.path.join(path, 'sbu_images', image)
        if os.path.exists(filename):
            caption = line2.rstrip()
            images.append(np.asarray(Image.open(filename).convert('RGB')).astype(np.uint8))
            captions.append(caption)
    return images, captions


def visualize_dataset(images, captions):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i].squeeze())
        plt.title(captions[i])
    plt.show()


def test_sbu_content_check():
    """
    Validate SBUDataset image readings
    """
    logger.info("Test SBUDataset Op with content check")
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False)
    images, captions = load_sbu(DATA_DIR)
    num_iter = 0
    # in this example, each dictionary has keys "image" and "caption"
    for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        assert data["image"].shape == images[i].shape
        assert data["caption"].item().decode("utf8") == captions[i]
        num_iter += 1
    assert num_iter == 5


def test_sbu_case():
    """
    Validate SBUDataset cases
    """
    dataset = ds.SBUDataset(DATA_DIR, decode=True)

    dataset = dataset.map(operations=[vision.Resize((224, 224))], input_columns=["image"])
    repeat_num = 4
    dataset = dataset.repeat(repeat_num)
    batch_size = 2
    dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})

    num = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num += 1
    # 4 x 5 / 2
    assert num == 10

    dataset = ds.SBUDataset(DATA_DIR, decode=False)

    dataset = dataset.map(operations=[vision.Decode(rgb=True), vision.Resize((224, 224))], input_columns=["image"])
    repeat_num = 4
    dataset = dataset.repeat(repeat_num)
    batch_size = 2
    dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})

    num = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num += 1
    # 4 x 5 / 2
    assert num == 10


def test_sbu_basic():
    """
    Validate SBUDataset
    """
    logger.info("Test SBUDataset Op")

    # case 1: test loading whole dataset
    dataset = ds.SBUDataset(DATA_DIR, decode=True)
    num_iter = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5

    # case 2: test num_samples
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    num_iter = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5

    # case 3: test repeat
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    dataset = dataset.repeat(5)
    num_iter = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 25

    # case 4: test batch
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert dataset.get_dataset_size() == 5
    assert dataset.get_batch_size() == 1

    num_iter = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5

    # case 5: test get_class_indexing
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert dataset.get_class_indexing() == {}

    # case 6: test get_col_names
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert dataset.get_col_names() == ["image", "caption"]


def test_sbu_sequential_sampler():
    """
    Test SBUDataset with SequentialSampler
    """
    logger.info("Test SBUDataset Op with SequentialSampler")
    num_samples = 5
    sampler = ds.SequentialSampler(num_samples=num_samples)
    dataset_1 = ds.SBUDataset(DATA_DIR, decode=True, sampler=sampler)
    dataset_2 = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=num_samples)

    num_iter = 0
    for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["caption"], item2["caption"])
        num_iter += 1
    assert num_iter == num_samples


def test_sbu_exception():
    """
    Test error cases for SBUDataset
    """
    logger.info("Test error cases for SBUDataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler())

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.SBUDataset(DATA_DIR, decode=True, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    error_msg_8 = "The corresponding data files"
    with pytest.raises(RuntimeError, match=error_msg_8):
        dataset = ds.SBUDataset(DATA_DIR, decode=True)
        dataset = dataset.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in dataset.__iter__():
            pass

    with pytest.raises(RuntimeError, match=error_msg_8):
        dataset = ds.SBUDataset(DATA_DIR, decode=True)
        dataset = dataset.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        for _ in dataset.__iter__():
            pass

    error_msg_9 = "does not exist or permission denied"
    with pytest.raises(ValueError, match=error_msg_9):
        dataset = ds.SBUDataset(WRONG_DIR, decode=True)
        for _ in dataset.__iter__():
            pass

    error_msg_10 = "Argument decode with value"
    with pytest.raises(TypeError, match=error_msg_10):
        dataset = ds.SBUDataset(DATA_DIR, decode="not_bool")
        for _ in dataset.__iter__():
            pass


def test_sbu_visualize(plot=False):
    """
    Visualize SBUDataset results
    """
    logger.info("Test SBUDataset visualization")

    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, caption_list = [], []
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        caption = item["caption"].item().decode("utf8")
        image_list.append(image)
        caption_list.append("caption {}".format(caption))
        assert isinstance(image, np.ndarray)

        assert image.dtype == np.uint8
        assert isinstance(caption, str)
        num_iter += 1
    assert num_iter == 5
    if plot:
        visualize_dataset(image_list, caption_list)


def test_sbu_decode():
    """
    Validate SBUDataset image readings
    """
    logger.info("Test SBUDataset decode flag")

    sampler = ds.SequentialSampler(num_samples=50)
    dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler)
    dataset_1 = dataset.map(operations=[vision.Decode(rgb=True)], input_columns=["image"])

    dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler)

    num_iter = 0
    for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["caption"], item2["caption"])
        num_iter += 1

    assert num_iter == 5


if __name__ == '__main__':
    test_sbu_content_check()
    test_sbu_basic()
    test_sbu_case()
    test_sbu_sequential_sampler()
    test_sbu_exception()
    test_sbu_visualize(plot=True)
    test_sbu_decode()