!18738 [assistant][ops] New operator implementation, including LSUNDataset

Merge pull request !18738 from Wangsong95/lsun_dataset
This commit is contained in:
i-robot 2022-02-23 02:28:40 +00:00 committed by Gitee
commit eeb731ae3e
23 changed files with 1828 additions and 4 deletions

View File

@@ -101,6 +101,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
@@ -1438,6 +1439,36 @@ LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of LSUNDataset.
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of LSUNDataset.
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of LSUNDataset.
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler,
const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,

View File

@@ -69,6 +69,7 @@
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
@@ -439,6 +440,18 @@ PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(LSUNNode, 2, ([](const py::module *m) {
(void)py::class_<LSUNNode, DatasetNode, std::shared_ptr<LSUNNode>>(*m, "LSUNNode",
"to create a LSUNNode")
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
const std::vector<std::string> &classes, bool decode, const py::handle &sampler) {
auto lsun =
std::make_shared<LSUNNode>(dataset_dir, usage, classes, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(lsun->ValidateParams());
return lsun;
}));
}));
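// (On the Python side, LSUNDataset.parse() in datasets.py later in this PR invokes this
// binding as cde.LSUNNode(dataset_dir, usage, classes, decode, sampler).)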
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode")

View File

@@ -30,6 +30,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
kmnist_op.cc
libri_tts_op.cc
lj_speech_op.cc
lsun_op.cc
mappable_leaf_op.cc
mnist_op.cc
multi30k_op.cc

View File

@@ -108,7 +108,7 @@ class ImageFolderOp : public MappableLeafOp {
/// \return Status of the function
Status GetNumClasses(int64_t *num_classes) override;
- private:
+ protected:
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
@@ -119,7 +119,7 @@ class ImageFolderOp : public MappableLeafOp {
/// @param std::string & dir - dir to walk all images
/// @param int64_t * cnt - number of non folder files under the current dir
/// @return
- Status RecursiveWalkFolder(Path *dir);
+ virtual Status RecursiveWalkFolder(Path *dir);
/// start walking of all dirs
/// @return

View File

@@ -0,0 +1,168 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/lsun_op.h"
#include <fstream>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
LSUNOp::LSUNOp(int32_t num_wkrs, const std::string &file_dir, int32_t queue_size, const std::string &usage,
const std::vector<std::string> &classes, bool do_decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: ImageFolderOp(num_wkrs, file_dir, queue_size, false, do_decode, {}, {}, std::move(data_schema),
std::move(sampler)),
usage_(usage),
classes_(classes) {}
void LSUNOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows: " << num_rows_ << "\nLSUN directory: " << folder_path_
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
}
}
Status LSUNOp::WalkDir(Path *dir, const std::string &usage, const std::vector<std::string> &classes,
const std::unique_ptr<Queue<std::string>> &folder_name_queue, int64_t *num_class) {
RETURN_UNEXPECTED_IF_NULL(dir);
std::vector<std::string> split;
if (usage == "train" || usage == "all") {
split.push_back("_train");
}
if (usage == "valid" || usage == "all") {
split.push_back("_val");
}
uint64_t dirname_offset = dir->ToString().length();
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
CHECK_FAIL_RETURN_UNEXPECTED(dir_itr != nullptr,
"Invalid path, failed to open image dir: " + dir->ToString() + ", permission denied.");
std::set<std::string> classes_set;
std::vector<std::string> valid_classes = classes;
if (classes.empty()) {
valid_classes = {"bedroom", "bridge", "church_outdoor", "classroom", "conference_room",
"dining_room", "kitchen", "living_room", "restaurant", "tower"};
}
while (dir_itr->HasNext()) {
std::string subdir = dir_itr->Next().ToString();
std::string name = subdir.substr(dirname_offset);
for (const auto &str : split) {
for (const auto &class_name : valid_classes) {
if (name.find(class_name + str) != std::string::npos) {
RETURN_IF_NOT_OK(folder_name_queue->EmplaceBack(name));
classes_set.insert(class_name);
}
}
}
}
if (num_class != nullptr) {
*num_class = classes_set.size();
}
return Status::OK();
}
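A minimal standalone sketch of the matching rule implemented above (illustration only, not part of this PR): a subfolder is kept when its name contains "<class><suffix>", with the suffix set derived from usage.
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Assumed inputs: usage == "train", so the only suffix is "_train".
  const std::vector<std::string> classes = {"bedroom", "classroom"};
  const std::vector<std::string> suffixes = {"_train"};
  const std::vector<std::string> subdirs = {"/bedroom_train", "/bedroom_val", "/classroom_train", "/test"};
  for (const std::string &name : subdirs) {
    for (const std::string &suffix : suffixes) {
      for (const std::string &cls : classes) {
        if (name.find(cls + suffix) != std::string::npos) {
          std::cout << name << " -> class " << cls << std::endl;
        }
      }
    }
  }
  return 0;  // prints: /bedroom_train -> class bedroom, /classroom_train -> class classroom
}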
// Base-class override that enqueues the LSUN folders matching usage_ and classes_
Status LSUNOp::RecursiveWalkFolder(Path *dir) {
RETURN_UNEXPECTED_IF_NULL(dir);
if (usage_ == "test") {
Path folder(folder_path_);
folder = folder / "test";
RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(folder.ToString().substr(dirname_offset_)));
return Status::OK();
}
RETURN_IF_NOT_OK(WalkDir(dir, usage_, classes_, folder_name_queue_, nullptr));
return Status::OK();
}
Status LSUNOp::CountRowsAndClasses(const std::string &path, const std::string &usage,
const std::vector<std::string> &classes, int64_t *num_rows, int64_t *num_classes) {
Path dir(path);
int64_t row_cnt = 0;
CHECK_FAIL_RETURN_UNEXPECTED(dir.Exists() && dir.IsDirectory(), "Invalid parameter, input dataset path " + path +
" does not exist or is not a directory.");
CHECK_FAIL_RETURN_UNEXPECTED(num_classes != nullptr || num_rows != nullptr,
"[Internal ERROR] num_classes and num_rows are both null.");
int32_t queue_size = 1024;
auto folder_name_queue = std::make_unique<Queue<std::string>>(queue_size);
RETURN_IF_NOT_OK(WalkDir(&dir, usage, classes, folder_name_queue, num_classes));
// return here if only num_classes is needed
RETURN_OK_IF_TRUE(num_rows == nullptr);
while (!folder_name_queue->empty()) {
std::string name;
RETURN_IF_NOT_OK(folder_name_queue->PopFront(&name));
Path subdir(path + name);
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&subdir);
while (dir_itr->HasNext()) {
++row_cnt;
Path subdir_pic = dir_itr->Next();
}
}
*num_rows = row_cnt;
return Status::OK();
}
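A hedged usage sketch of the counter above, against the testLSUN data added later in this PR (the path is a placeholder; expected values come from the size and getters tests below):
int64_t num_rows = 0, num_classes = 0;
// usage "train" with an empty classes list walks every "<class>_train" folder.
Status s = LSUNOp::CountRowsAndClasses("/path/to/testLSUN", "train", {}, &num_rows, &num_classes);
// Expected per TestLSUNGetDatasetSize and TestLSUNDatasetGetters: num_rows == 2, num_classes == 2.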
Status LSUNOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
ImageLabelPair pair_ptr = image_label_pairs_[row_id];
std::shared_ptr<Tensor> image, label;
uint32_t label_num = static_cast<uint32_t>(pair_ptr->second);
RETURN_IF_NOT_OK(Tensor::CreateScalar(label_num, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));
if (decode_) {
Status rc = Decode(image, &image);
if (rc.IsError()) {
std::string err = "Invalid image, " + folder_path_ + (pair_ptr->first) +
" decode failed, the image is broken or permission denied.";
RETURN_STATUS_UNEXPECTED(err);
}
}
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
return Status::OK();
}
// Get number of classes
Status LSUNOp::GetNumClasses(int64_t *num_classes) {
RETURN_UNEXPECTED_IF_NULL(num_classes);
if (num_classes_ > 0) {
*num_classes = num_classes_;
return Status::OK();
}
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, usage_, classes_, nullptr, num_classes));
num_classes_ = *num_classes;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LSUN_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LSUN_OP_H_
#include <memory>
#include <queue>
#include <string>
#include <set>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
namespace mindspore {
namespace dataset {
/// \brief Forward declaration.
template <typename T>
class Queue;
using ImageLabelPair = std::shared_ptr<std::pair<std::string, int32_t>>;
using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>;
class LSUNOp : public ImageFolderOp {
public:
/// \brief Constructor.
/// \param[in] num_wkrs Number of workers reading images in parallel.
/// \param[in] file_dir Directory of the LSUN dataset.
/// \param[in] queue_size Connector queue size.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] do_decode Decode the images after reading.
/// \param[in] data_schema Schema of data.
/// \param[in] sampler Sampler that tells LSUNOp what to read.
LSUNOp(int32_t num_wkrs, const std::string &file_dir, int32_t queue_size, const std::string &usage,
const std::vector<std::string> &classes, bool do_decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~LSUNOp() = default;
/// \brief A print method typically used for debugging.
/// \param[out] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Function to count the number of samples and classes in the LSUN dataset.
/// \param[in] path Path to the LSUN directory.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[out] num_rows Output arg that will hold the number of rows in the dataset.
/// \param[out] num_classes Output arg that will hold the number of classes in the dataset.
/// \return Status The status code returned.
static Status CountRowsAndClasses(const std::string &path, const std::string &usage,
const std::vector<std::string> &classes, int64_t *num_rows, int64_t *num_classes);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "LSUNOp"; }
/// \brief Dataset name getter.
/// \param[in] upper Whether to get upper name.
/// \return Dataset name of the current Op.
std::string DatasetName(bool upper = false) const override { return upper ? "LSUN" : "lsun"; }
/// \brief Load a tensor row according to a pair
/// \param[in] row_id id for this tensor row
/// \param[out] trow image & label read into this tensor row
/// \return Status The status code returned
Status LoadTensorRow(row_id_type row_id, TensorRow *trow) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function
Status GetNumClasses(int64_t *num_classes) override;
private:
/// \brief Base-class override for RecursiveWalkFolder
/// \param[in] dir Path to the LSUN dataset directory.
/// \return Status of the function
Status RecursiveWalkFolder(Path *dir) override;
/// \brief Function to walk the dataset directory and collect the folder names matching usage and classes.
/// \param[in] dir Path to the LSUN dataset directory.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[out] folder_name_queue Output arg that will hold the matching folder names.
/// \param[out] num_class Output arg that will hold the number of classes.
/// \return Status of the function
static Status WalkDir(Path *dir, const std::string &usage, const std::vector<std::string> &classes,
const std::unique_ptr<Queue<std::string>> &folder_name_queue, int64_t *num_class);
std::string usage_;
std::vector<std::string> classes_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LSUN_OP_H_

View File

@@ -105,6 +105,7 @@ constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset";
constexpr char kKMnistNode[] = "KMnistDataset";
constexpr char kLibriTTSNode[] = "LibriTTSDataset";
constexpr char kLJSpeechNode[] = "LJSpeechDataset";
constexpr char kLSUNNode[] = "LSUNDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";

View File

@@ -31,6 +31,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
kmnist_node.cc
libri_tts_node.cc
lj_speech_node.cc
lsun_node.cc
manifest_node.cc
minddata_node.cc
mnist_node.cc

View File

@@ -0,0 +1,128 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/lsun_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
LSUNNode::LSUNNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &classes,
bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
classes_(classes),
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> LSUNNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<LSUNNode>(dataset_dir_, usage_, classes_, decode_, sampler, cache_);
return node;
}
void LSUNNode::Print(std::ostream &out) const {
out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
}
Status LSUNNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("LSUNDataset", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("LSUNDataset", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", usage_, {"train", "test", "valid", "all"}));
for (const auto &class_name : classes_) {
RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", class_name,
{"bedroom", "bridge", "church_outdoor", "classroom", "conference_room",
"dining_room", "kitchen", "living_room", "restaurant", "tower"}));
}
return Status::OK();
}
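ValidateParams restricts usage to {"train", "test", "valid", "all"} and classes to the ten LSUN categories; an invalid value surfaces when the iterator is created. A hedged sketch mirroring the failure tests later in this PR:
// "validation" is not an accepted usage value, so iterator creation is expected to fail.
std::shared_ptr<Dataset> ds = LSUN("/path/to/lsun", "validation", {}, false, std::make_shared<RandomSampler>());
std::shared_ptr<Iterator> iter = ds->CreateIterator();  // expected: nullptr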
Status LSUNNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Do internal Schema generation.
// This arg exists in LSUNOp, but is not externalized (in the Python API).
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<LSUNOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, classes_, decode_,
std::move(schema), std::move(sampler_rt));
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
RETURN_UNEXPECTED_IF_NULL(node_ops);
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node
Status LSUNNode::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size
Status LSUNNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(LSUNOp::CountRowsAndClasses(dataset_dir_, usage_, classes_, &num_rows, nullptr));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status LSUNNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["classes"] = classes_;
args["decode"] = decode_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
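// Illustrative to_json output for a hypothetical pipeline (the values below are
// assumptions; only the keys are fixed by the code above):
//   {"sampler": {...}, "num_parallel_workers": 8, "dataset_dir": "/path/to/lsun",
//    "usage": "all", "classes": [], "decode": false}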
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LSUN_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LSUN_NODE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class LSUNNode
/// \brief A Dataset derived class to represent LSUN dataset
class LSUNNode : public MappableSourceNode {
public:
/// \brief Constructor
/// \param[in] dataset_dir Dataset directory of LSUNDataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Specified LSUN classes to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Tells LSUNOp what to read.
/// \param[in] cache Tensor cache to use.
LSUNNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &classes,
bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~LSUNNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kLSUNNode; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \param[out] shard_id The shard id of the node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
const std::vector<std::string> &Classes() const { return classes_; }
bool Decode() const { return decode_; }
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Specify sampler.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
std::string usage_;
std::vector<std::string> classes_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LSUN_NODE_H_

View File

@@ -388,7 +388,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced.
/// \param[in] project_columns A list of column names to project.
- /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
+ /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \param[in] callbacks List of Dataset callbacks to be called.
/// \return Shared pointer to the current Dataset.
/// \par Example
@@ -3405,6 +3405,113 @@ inline std::shared_ptr<LJSpeechDataset> MS_API LJSpeech(const std::string &datas
return std::make_shared<LJSpeechDataset>(StringToChar(dataset_dir), sampler, cache);
}
/// \class LSUNDataset
/// \brief A source dataset for reading the LSUN dataset.
class MS_API LSUNDataset : public Dataset {
public:
/// \brief Constructor of LSUNDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of LSUNDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of LSUNDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &classes, bool decode, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of LSUNDataset.
~LSUNDataset() = default;
};
/// \brief Function to create a LSUNDataset.
/// \note The generated dataset has two columns "image" and "label".
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all` (Default=`all`).
/// \param[in] classes Classes list to load, such as 'bedroom', 'classroom' (Default={}, which means all classes are loaded).
/// \param[in] decode Decode the images after reading (Default=false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (Default = RandomSampler()).
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current LSUNDataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string folder_path = "/path/to/lsun_dataset_directory";
/// std::shared_ptr<Dataset> ds = LSUN(folder_path, "all");
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In LSUNDataset, each data dictionary has keys "image" and "label" */
/// auto image = row["image"];
/// \endcode
inline std::shared_ptr<LSUNDataset> MS_API
LSUN(const std::string &dataset_dir, const std::string &usage = "all", const std::vector<std::string> &classes = {},
bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LSUNDataset>(StringToChar(dataset_dir), StringToChar(usage), VectorStringToChar(classes),
decode, sampler, cache);
}
/// \brief Function to create a LSUNDataset.
/// \note The generated dataset has two columns "image" and "label".
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current LSUNDataset.
inline std::shared_ptr<LSUNDataset> MS_API LSUN(const std::string &dataset_dir, const std::string &usage,
const std::vector<std::string> &classes, bool decode,
const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LSUNDataset>(StringToChar(dataset_dir), StringToChar(usage), VectorStringToChar(classes),
decode, sampler, cache);
}
/// \brief Function to create a LSUNDataset.
/// \note The generated dataset has two columns "image" and "label".
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Dataset splits of LSUN, can be `train`, `valid`, `test` or `all`.
/// \param[in] classes Classes list to load.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current LSUNDataset.
inline std::shared_ptr<LSUNDataset> MS_API LSUN(const std::string &dataset_dir, const std::string &usage,
const std::vector<std::string> &classes, bool decode,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LSUNDataset>(StringToChar(dataset_dir), StringToChar(usage), VectorStringToChar(classes),
decode, sampler, cache);
}
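Taken together, the three overloads above let a caller hand the sampler over by shared_ptr, raw pointer, or reference. A hedged sketch of the three call styles (the path is a placeholder):
std::string path = "/path/to/lsun_dataset_directory";
// 1) shared_ptr sampler (the defaulted form shown in the example above).
std::shared_ptr<Dataset> ds1 = LSUN(path, "train", {}, false, std::make_shared<RandomSampler>());
// 2) Raw-pointer sampler; the sampler object must outlive the call.
SequentialSampler seq(0, 10);
std::shared_ptr<Dataset> ds2 = LSUN(path, "train", {}, false, &seq);
// 3) reference_wrapper sampler.
std::shared_ptr<Dataset> ds3 = LSUN(path, "train", {}, false, std::reference_wrapper<Sampler>(seq));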
/// \class ManifestDataset
/// \brief A source dataset for reading and parsing Manifest dataset.
class MS_API ManifestDataset : public Dataset {

View File

@@ -53,6 +53,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
friend class KMnistDataset;
friend class LibriTTSDataset;
friend class LJSpeechDataset;
friend class LSUNDataset;
friend class ManifestDataset;
friend class MindDataDataset;
friend class MnistDataset;

View File

@@ -50,6 +50,7 @@ __all__ = ["Caltech101Dataset", # Vision
"FlickrDataset", # Vision
"Flowers102Dataset", # Vision
"ImageFolderDataset", # Vision
"KMnistDataset", # Vision
"LSUNDataset", # Vision
"ManifestDataset", # Vision
"MnistDataset", # Vision

View File

@@ -38,7 +38,7 @@ from .validators import check_imagefolderdataset, \
check_usps_dataset, check_div2k_dataset, check_random_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_svhn_dataset, check_stl10_dataset, check_semeion_dataset, \
- check_caltech101_dataset, check_caltech256_dataset, check_wider_face_dataset
+ check_caltech101_dataset, check_caltech256_dataset, check_wider_face_dataset, check_lsun_dataset
from ..core.validator_helpers import replace_none
@@ -2419,6 +2419,136 @@ class KMnistDataset(MappableDataset, VisionBaseDataset):
return cde.KMnistNode(self.dataset_dir, self.usage, self.sampler)
class LSUNDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads and parses the LSUN dataset.
The generated dataset has two columns: :py:obj:`[image, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be `train`, `test`, `valid` or `all`
(default=None, will be set to `all`).
classes (Union[str, list[str]], optional): Choose the specific classes to load (default=None, which means
loading all classes in the root directory).
num_samples (int, optional): The number of images to be included in the dataset
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, set in the config).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
decode (bool, optional): Decode the images after reading (default=False).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, 'num_samples' reflects
the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Raises:
RuntimeError: If 'sampler' and 'shuffle' are specified at the same time.
RuntimeError: If 'sampler' and sharding are specified at the same time.
RuntimeError: If 'num_shards' is specified but 'shard_id' is None.
RuntimeError: If 'shard_id' is specified but 'num_shards' is None.
ValueError: If 'shard_id' is invalid (< 0 or >= num_shards).
ValueError: If 'usage' or 'classes' is invalid (not within the expected values).
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> lsun_dataset_dir = "/path/to/lsun_dataset_directory"
>>>
>>> # 1) Read all samples (image files) in lsun_dataset_dir with 8 threads
>>> dataset = ds.LSUNDataset(dataset_dir=lsun_dataset_dir,
... num_parallel_workers=8)
>>>
>>> # 2) Read all train samples (image files) from folder "bedroom" and "classroom"
>>> dataset = ds.LSUNDataset(dataset_dir=lsun_dataset_dir, usage="train",
... classes=["bedroom", "classroom"])
About LSUN dataset:
The LSUN dataset was built to assess the effectiveness of a cascading labeling procedure that combines
deep learning with humans in the loop, and to enable further progress in visual recognition research.
The LSUN dataset contains around one million labeled images for each of 10 scene categories
and 20 object categories. The author experimented with training popular convolutional networks and found
that they achieved substantial performance gains when trained on this dataset.
You can unzip the original LSUN dataset files into this directory structure using the official data.py, and
then read them with MindSpore's API.
.. code-block::
.
└── lsun_dataset_directory
    ├── test
    │    ├── ...
    ├── bedroom_train
    │    ├── 1_1.jpg
    │    ├── 1_2.jpg
    ├── bedroom_val
    │    ├── ...
    ├── classroom_train
    │    ├── ...
    ├── classroom_val
    │    ├── ...
Citation:
.. code-block::
@article{yu15lsun,
title={LSUN: Construction of a Large-scale Image Dataset using Deep Learning with Humans in the Loop},
author={Yu, Fisher and Zhang, Yinda and Song, Shuran and Seff, Ari and Xiao, Jianxiong},
journal={arXiv preprint arXiv:1506.03365},
year={2015}
}
"""
@check_lsun_dataset
def __init__(self, dataset_dir, usage=None, classes=None, num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all")
self.classes = replace_none(classes, [])
self.decode = replace_none(decode, False)
def parse(self, children=None):
return cde.LSUNNode(self.dataset_dir, self.usage, self.classes, self.decode, self.sampler)
class ManifestDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading images from a Manifest file.

View File

@@ -258,6 +258,47 @@ def check_iwslt2017_dataset(method):
return new_method
def check_lsun_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(LSUNDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
nreq_param_list = ['classes']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "valid", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
validate_dataset_param_value(nreq_param_list, param_dict, list)
categories = [
'bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower'
]
classes = param_dict.get('classes')
if classes is not None:
for class_name in classes:
check_valid_str(class_name, categories, "classes")
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_mnist_cifar_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

View File

@@ -40,6 +40,7 @@ SET(DE_UT_SRCS
c_api_dataset_kmnist_test.cc
c_api_dataset_libri_tts.cc
c_api_dataset_lj_speech_test.cc
c_api_dataset_lsun_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_multi30k_test.cc

View File

@@ -0,0 +1,366 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with usage "train".
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNTrainDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNTrainDataset.";
// Create a LSUN Train Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "train", {}, false, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with usage "valid".
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNValidDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNValidDataset.";
// Create a LSUN Validation Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "valid", {}, false, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with usage "test".
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNTestDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNTestDataset.";
// Create a LSUN Test Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "test", {}, false, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with usage "all".
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNAllDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNAllDataset.";
// Create a LSUN Dataset with usage "all".
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "all", {}, false, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with specified classes.
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNClassesDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNClassesDataset.";
// Create a LSUN Train Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds =
LSUN(folder_path, "train", {"bedroom", "classroom"}, false, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset in a pipeline with Repeat, Project and Concat operations.
/// Expectation: get correct LSUN dataset.
TEST_F(MindDataTestPipeline, TestLSUNDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNDatasetWithPipeline.";
// Create two LSUN Train Datasets.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds1 =
LSUN(folder_path, "train", {"bedroom", "classroom"}, false, std::make_shared<RandomSampler>(false, 5));
std::shared_ptr<Dataset> ds2 =
LSUN(folder_path, "train", {"bedroom", "classroom"}, false, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operations on the datasets.
int32_t repeat_num = 1;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 1;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operations on the datasets.
std::vector<std::string> column_project = {"image", "label"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the datasets.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 4);
// Manually terminate the pipeline.
iter->Stop();
}
/// Feature: LSUNDataset.
/// Description: test GetDatasetSize of LSUNDataset.
/// Expectation: get correct dataset size.
TEST_F(MindDataTestPipeline, TestLSUNGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNGetDatasetSize.";
// Create a LSUN Train Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "train", {}, false);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 2);
}
/// Feature: LSUNDataset.
/// Description: test GetDatasetSize of LSUNDataset with specified classes.
/// Expectation: get correct dataset size.
TEST_F(MindDataTestPipeline, TestLSUNClassesGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetLSUNClassesDatasetSize.";
// Create a LSUN Train Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "train", {"bedroom", "classroom"}, false);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 2);
}
/// Feature: LSUNDataset.
/// Description: test getters of LSUNDataset.
/// Expectation: get correct dataset attributes.
TEST_F(MindDataTestPipeline, TestLSUNDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNDatasetGetters.";
// Create a LSUN Train Dataset.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "train", {}, true);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 2);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(num_classes, 2);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with wrong folder path.
/// Expectation: throw exception correctly.
TEST_F(MindDataTestPipeline, TestLSUNDatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNDatasetFail.";
// Create a LSUN Dataset in which the folder path is invalid.
std::shared_ptr<Dataset> ds = LSUN("", "train", {}, false, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid LSUN input, folder path is invalid.
EXPECT_EQ(iter, nullptr);
}
/// Feature: LSUNDataset.
/// Description: test LSUNDataset with null sampler.
/// Expectation: throw exception correctly.
TEST_F(MindDataTestPipeline, TestLSUNDatasetWithNullSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLSUNDatasetWithNullSamplerFail.";
// Create a LSUN Dataset in which the Sampler is not provided.
std::string folder_path = datasets_root_path_ + "/testLSUN";
std::shared_ptr<Dataset> ds = LSUN(folder_path, "train", {}, false, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid LSUN input, sampler cannot be nullptr.
EXPECT_EQ(iter, nullptr);
}

5 binary files (test images) not shown. Sizes: 155 KiB, 75 KiB, 45 KiB, 422 KiB, 63 KiB.
View File

@@ -0,0 +1,609 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test LSUN dataset operators
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testLSUN"
def test_lsun_basic():
"""
Feature: LSUN
Description: test basic usage of LSUN
Expectation: the dataset is as expected
"""
logger.info("Test Case basic")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_lsun_num_samples():
"""
Feature: LSUN
Description: test LSUNDataset with num_samples and RandomSampler
Expectation: the dataset is as expected
"""
logger.info("Test Case num_samples")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.LSUNDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.LSUNDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3
def test_lsun_num_shards():
"""
Feature: LSUN
Description: test LSUNDataset with num_shards and shard_id
Expectation: the dataset is as expected
"""
logger.info("Test Case numShards")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=1)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
def test_lsun_shard_id():
"""
Feature: LSUN
Description: test LSUNDataset with shard_id
Expectation: the dataset is as expected
"""
logger.info("Test Case withShardID")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=0)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
def test_lsun_no_shuffle():
"""
Feature: LSUN
Description: test LSUNDataset with shuffle=False
Expectation: the dataset is as expected
"""
logger.info("Test Case noShuffle")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4


def test_lsun_extra_shuffle():
    """
    Feature: LSUN
    Description: test LSUN with an extra shuffle operation on top of shuffle=True
    Expectation: the dataset is as expected
    """
logger.info("Test Case extra_shuffle")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, shuffle=True)
data1 = data1.shuffle(buffer_size=5)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_lsun_decode():
    """
    Feature: LSUN
    Description: test LSUN with decode=True
    Expectation: the dataset is as expected
    """
logger.info("Test Case decode")
# define parameters
repeat_count = 1
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, decode=True)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
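

# A minimal sketch, not part of the original suite: decoded LSUN images can
# feed a standard map pipeline. vision.Resize and the 32x32 target are
# illustrative choices; the final count of 4 assumes the same mock data as
# test_lsun_decode.
def sketch_lsun_decode_then_resize():
    data = ds.LSUNDataset(DATA_DIR, decode=True)
    data = data.map(operations=vision.Resize((32, 32)), input_columns=["image"], num_parallel_workers=1)
    num_iter = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        # Resize outputs HWC images, so the first two dims are the new size.
        assert item["image"].shape[:2] == (32, 32)
        num_iter += 1
    assert num_iter == 4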


def test_sequential_sampler():
    """
    Feature: LSUN
    Description: test LSUN with SequentialSampler
    Expectation: the dataset is as expected
    """
logger.info("Test Case SequentialSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.SequentialSampler(num_samples=10)
data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
data1 = data1.repeat(repeat_count)
result = []
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
result.append(item["label"])
num_iter += 1
assert num_iter == 2
logger.info("Result: {}".format(result))


def test_random_sampler():
    """
    Feature: LSUN
    Description: test LSUN with RandomSampler
    Expectation: the dataset is as expected
    """
logger.info("Test Case RandomSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.RandomSampler()
data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_distributed_sampler():
    """
    Feature: LSUN
    Description: test LSUN with DistributedSampler
    Expectation: the dataset is as expected
    """
logger.info("Test Case DistributedSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.DistributedSampler(2, 1)
data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 1


def test_pk_sampler():
    """
    Feature: LSUN
    Description: test LSUN with PKSampler
    Expectation: the dataset is as expected
    """
logger.info("Test Case PKSampler")
# define parameters
repeat_count = 1
# apply dataset operations
sampler = ds.PKSampler(1)
data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2
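

# A minimal sketch, not part of the original suite: WeightedRandomSampler is
# another sampler LSUNDataset accepts. The four weights assume the default
# usage yields four samples, matching the counts asserted above.
def sketch_lsun_weighted_random_sampler():
    sampler = ds.WeightedRandomSampler([0.7, 0.1, 0.1, 0.1], num_samples=2)
    data = ds.LSUNDataset(DATA_DIR, sampler=sampler)
    num_iter = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 2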


def test_chained_sampler():
    """
    Feature: LSUN
    Description: test LSUN with a chained sampler (RandomSampler with a SequentialSampler child) and repeat
    Expectation: the dataset is as expected
    """
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
# Create chained sampler, random and sequential
sampler = ds.RandomSampler()
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
# Create LSUNDataset with sampler
data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
data1 = data1.repeat(count=3)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 6
# Verify number of iterations
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6


def test_lsun_test_dataset():
    """
    Feature: LSUN
    Description: test LSUN with usage="test"
    Expectation: the dataset is as expected
    """
logger.info("Test Case usage")
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, usage="test", num_samples=8)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 1


def test_lsun_valid_dataset():
    """
    Feature: LSUN
    Description: test LSUN with usage="valid"
    Expectation: the dataset is as expected
    """
logger.info("Test Case usage")
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, usage="valid", num_samples=8)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_lsun_train_dataset():
    """
    Feature: LSUN
    Description: test LSUN with usage="train"
    Expectation: the dataset is as expected
    """
logger.info("Test Case usage")
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, usage="train", num_samples=8)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_lsun_all_dataset():
    """
    Feature: LSUN
    Description: test LSUN with usage="all"
    Expectation: the dataset is as expected
    """
logger.info("Test Case usage")
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, usage="all", num_samples=8)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
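

# A minimal sketch, not part of the original suite: get_dataset_size() should
# agree with the manual iteration counts above. The sizes assume, as those
# counts suggest, that usage="all" covers the train and valid splits (2 + 2).
def sketch_lsun_usage_sizes():
    assert ds.LSUNDataset(DATA_DIR, usage="train").get_dataset_size() == 2
    assert ds.LSUNDataset(DATA_DIR, usage="valid").get_dataset_size() == 2
    assert ds.LSUNDataset(DATA_DIR, usage="all").get_dataset_size() == 4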


def test_lsun_classes():
"""
Feature: LSUN
Description: test classes of LSUN
Expectation: the dataset is as expected
"""
logger.info("Test Case usage")
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, usage="train", classes=["bedroom"], num_samples=8)
num_iter = 0
# each data is a dictionary
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 1


def test_lsun_zip():
    """
    Feature: LSUN
    Description: test zip operation on two LSUN datasets
    Expectation: the dataset is as expected
    """
logger.info("Test Case zip")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.LSUNDataset(DATA_DIR, num_samples=10)
data2 = ds.LSUNDataset(DATA_DIR, num_samples=10)
data1 = data1.repeat(repeat_count)
# rename dataset2 for no conflict
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
data3 = ds.zip((data1, data2))
num_iter = 0
# each data is a dictionary
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
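

# A minimal sketch, not part of the original suite: a typical batching tail.
# Images must be decoded and resized to a common shape before batch() can
# stack them; batch_size=2 and drop_remainder=True are illustrative choices.
def sketch_lsun_batch():
    data = ds.LSUNDataset(DATA_DIR, decode=True)
    data = data.map(operations=vision.Resize((32, 32)), input_columns=["image"])
    data = data.batch(batch_size=2, drop_remainder=True)
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        # each batch stacks two HWC images into one NHWC tensor
        assert item["image"].shape[0] == 2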


def test_lsun_exception():
"""
Feature: LSUN
Description: test error cases for LSUN
Expectation: throw exception correctly
"""
logger.info("Test lsun exception")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.LSUNDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.LSUNDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.LSUNDataset(DATA_DIR, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.LSUNDataset(DATA_DIR, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.LSUNDataset(DATA_DIR, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.LSUNDataset(DATA_DIR, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id="0")


def test_lsun_exception_map():
"""
Feature: LSUN
Description: test error cases for LSUN
Expectation: throw exception correctly
"""
logger.info("Test lsun exception map")

    def exception_func(item):
        raise Exception("Error occurred!")

    def exception_func2(image, label):
        raise Exception("Error occurred!")

try:
data = ds.LSUNDataset(DATA_DIR)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.LSUNDataset(DATA_DIR)
data = data.map(operations=exception_func2,
input_columns=["image", "label"],
output_columns=["image", "label", "label1"],
column_order=["image", "label", "label1"],
num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.LSUNDataset(DATA_DIR)
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


if __name__ == '__main__':
test_lsun_basic()
test_lsun_num_samples()
test_sequential_sampler()
test_random_sampler()
test_distributed_sampler()
    test_pk_sampler()
    test_chained_sampler()
test_lsun_num_shards()
test_lsun_shard_id()
test_lsun_no_shuffle()
test_lsun_extra_shuffle()
test_lsun_decode()
test_lsun_test_dataset()
test_lsun_valid_dataset()
test_lsun_train_dataset()
test_lsun_all_dataset()
test_lsun_classes()
test_lsun_zip()
test_lsun_exception()
test_lsun_exception_map()