!5605 Introduce usage flag to MNIST and CIFAR dataset
Merge pull request !5605 from ZiruiWu/add_usage_to_cifar_mnist_coco
This commit is contained in:
commit
ea94756839
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include <unordered_set>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
@ -132,26 +132,28 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
|
|||
}
|
||||
|
||||
// Function to create a CelebADataset.
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler, bool decode,
|
||||
const std::set<std::string> &extensions) {
|
||||
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
|
||||
auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a Cifar10Dataset.
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a Cifar100Dataset.
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -217,8 +219,9 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
|
|||
#endif
|
||||
|
||||
// Function to create a MnistDataset.
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -244,10 +247,10 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Function to create a VOCDataset.
|
||||
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
|
||||
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -727,6 +730,10 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) {
|
||||
return valid_strings.find(str) != valid_strings.end();
|
||||
}
|
||||
|
||||
// Helper function to validate dataset input/output column parameter
|
||||
bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
|
||||
const std::vector<std::string> &columns) {
|
||||
|
@ -802,29 +809,14 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumDataset::Build() {
|
|||
}
|
||||
|
||||
// Constructor for CelebADataset
|
||||
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
|
||||
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
||||
const std::set<std::string> &extensions)
|
||||
: dataset_dir_(dataset_dir),
|
||||
dataset_type_(dataset_type),
|
||||
sampler_(sampler),
|
||||
decode_(decode),
|
||||
extensions_(extensions) {}
|
||||
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}
|
||||
|
||||
bool CelebADataset::ValidateParams() {
|
||||
if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
|
||||
return false;
|
||||
}
|
||||
if (!ValidateDatasetSampler("CelebADataset", sampler_)) {
|
||||
return false;
|
||||
}
|
||||
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
|
||||
auto iter = dataset_type_list.find(dataset_type_);
|
||||
if (iter == dataset_type_list.end()) {
|
||||
MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return ValidateDatasetDirParam("CelebADataset", dataset_dir_) && ValidateDatasetSampler("CelebADataset", sampler_) &&
|
||||
ValidateStringValue(usage_, {"all", "train", "valid", "test"});
|
||||
}
|
||||
|
||||
// Function to build CelebADataset
|
||||
|
@ -839,17 +831,20 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, dataset_type_, extensions_, std::move(schema),
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for Cifar10Dataset
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
bool Cifar10Dataset::ValidateParams() {
|
||||
return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && ValidateDatasetSampler("Cifar10Dataset", sampler_);
|
||||
return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) &&
|
||||
ValidateDatasetSampler("Cifar10Dataset", sampler_) &&
|
||||
ValidateStringValue(usage_, {"train", "test", "all", ""});
|
||||
}
|
||||
|
||||
// Function to build CifarOp for Cifar10
|
||||
|
@ -864,19 +859,21 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for Cifar100Dataset
|
||||
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
bool Cifar100Dataset::ValidateParams() {
|
||||
return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) &&
|
||||
ValidateDatasetSampler("Cifar100Dataset", sampler_);
|
||||
ValidateDatasetSampler("Cifar100Dataset", sampler_) &&
|
||||
ValidateStringValue(usage_, {"train", "test", "all", ""});
|
||||
}
|
||||
|
||||
// Function to build CifarOp for Cifar100
|
||||
|
@ -893,7 +890,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
|
@ -1360,11 +1357,12 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestDataset::Build() {
|
|||
}
|
||||
#endif
|
||||
|
||||
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
bool MnistDataset::ValidateParams() {
|
||||
return ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
|
||||
return ValidateStringValue(usage_, {"train", "test", "all", ""}) &&
|
||||
ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
||||
|
@ -1378,8 +1376,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
std::move(schema), std::move(sampler_->Build())));
|
||||
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
@ -1570,12 +1568,12 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Constructor for VOCDataset
|
||||
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir),
|
||||
task_(task),
|
||||
mode_(mode),
|
||||
usage_(usage),
|
||||
class_index_(class_indexing),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
@ -1594,15 +1592,15 @@ bool VOCDataset::ValidateParams() {
|
|||
MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
|
||||
return false;
|
||||
}
|
||||
Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
|
||||
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
|
||||
MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
return false;
|
||||
}
|
||||
} else if (task_ == "Detection") {
|
||||
Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
|
||||
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
|
||||
MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
@ -1641,7 +1639,7 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
|
|||
}
|
||||
|
||||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
node_ops.push_back(voc_op);
|
||||
return node_ops;
|
||||
|
|
|
@ -41,9 +41,9 @@ namespace dataset {
|
|||
|
||||
PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
|
||||
.def_static("get_num_rows", [](const std::string &dir, const std::string &usage, bool isCifar10) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, usage, isCifar10, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
@ -131,9 +131,9 @@ PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
|
|||
|
||||
PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir) {
|
||||
.def_static("get_num_rows", [](const std::string &dir, const std::string &usage) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, usage, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
|
|
@ -1354,25 +1354,14 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
|
||||
Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
if (args["dataset_dir"].is_none()) {
|
||||
std::string err_msg = "Error: No dataset path specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["task"].is_none()) {
|
||||
std::string err_msg = "Error: No task specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["mode"].is_none()) {
|
||||
std::string err_msg = "Error: No mode specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!args["dataset_dir"].is_none(), "Error: No dataset path specified.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
|
||||
|
||||
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
(void)builder->SetTask(ToString(args["task"]));
|
||||
(void)builder->SetMode(ToString(args["mode"]));
|
||||
(void)builder->SetUsage(ToString(args["usage"]));
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
|
@ -1461,6 +1450,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
|
|||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1495,6 +1486,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
|
|||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1608,6 +1601,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1645,8 +1640,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
(void)builder->SetDecode(ToBool(value));
|
||||
} else if (key == "extensions") {
|
||||
(void)builder->SetExtensions(ToStringSet(value));
|
||||
} else if (key == "dataset_type") {
|
||||
(void)builder->SetDatasetType(ToString(value));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr)
|
|||
|
||||
Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
|
||||
MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << ".";
|
||||
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
|
||||
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_usage_.c_str() << ".";
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
const int64_t num_samples = 0;
|
||||
|
@ -51,8 +51,8 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
|
|||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
*op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
builder_op_connector_size_, builder_decode_, builder_dataset_type_,
|
||||
builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
builder_op_connector_size_, builder_decode_, builder_usage_, builder_extensions_,
|
||||
std::move(builder_schema_), std::move(builder_sampler_));
|
||||
if (*op == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
|
|||
}
|
||||
|
||||
CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
|
||||
bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
|
||||
bool decode, const std::string &usage, const std::set<std::string> &exts,
|
||||
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, queue_size, std::move(sampler)),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
|
@ -78,7 +78,7 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
|
|||
extensions_(exts),
|
||||
data_schema_(std::move(schema)),
|
||||
num_rows_in_attr_file_(0),
|
||||
dataset_type_(dataset_type) {
|
||||
usage_(usage) {
|
||||
attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
|
||||
io_block_queues_.Init(num_workers_, queue_size);
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ Status CelebAOp::ParseAttrFile() {
|
|||
std::vector<std::string> image_infos;
|
||||
image_infos.reserve(oc_queue_size_);
|
||||
while (getline(attr_file, image_info)) {
|
||||
if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
|
||||
if ((image_info.empty()) || (usage_ != "all" && !CheckDatasetTypeValid())) {
|
||||
continue;
|
||||
}
|
||||
image_infos.push_back(image_info);
|
||||
|
@ -179,11 +179,11 @@ bool CelebAOp::CheckDatasetTypeValid() {
|
|||
return false;
|
||||
}
|
||||
// train:0, valid=1, test=2
|
||||
if (dataset_type_ == "train" && (type == 0)) {
|
||||
if (usage_ == "train" && (type == 0)) {
|
||||
return true;
|
||||
} else if (dataset_type_ == "valid" && (type == 1)) {
|
||||
} else if (usage_ == "valid" && (type == 1)) {
|
||||
return true;
|
||||
} else if (dataset_type_ == "test" && (type == 2)) {
|
||||
} else if (usage_ == "test" && (type == 2)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -109,10 +109,10 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string dataset_type: type to be read
|
||||
// @param const std::string usage: type to be read
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetDatasetType(const std::string &dataset_type) {
|
||||
builder_dataset_type_ = dataset_type;
|
||||
Builder &SetUsage(const std::string &usage) {
|
||||
builder_usage_ = usage;
|
||||
return *this;
|
||||
}
|
||||
// Check validity of input args
|
||||
|
@ -133,7 +133,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
std::set<std::string> builder_extensions_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
std::string builder_dataset_type_;
|
||||
std::string builder_usage_;
|
||||
};
|
||||
|
||||
// Constructor
|
||||
|
@ -143,12 +143,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
// @param int32_t queueSize - connector queue size
|
||||
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
|
||||
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
|
||||
const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
|
||||
const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
|
||||
~CelebAOp() override = default;
|
||||
|
||||
// Main Loop of CelebaOp
|
||||
// Main Loop of CelebAOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
|
@ -177,7 +177,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
std::string Name() const { return "CelebAOp"; }
|
||||
std::string Name() const override { return "CelebAOp"; }
|
||||
|
||||
private:
|
||||
// Called first when function is called
|
||||
|
@ -232,7 +232,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
WaitPost wp_;
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
|
||||
std::string dataset_type_;
|
||||
std::string usage_;
|
||||
std::ifstream partition_file_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -18,15 +18,16 @@
|
|||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -36,7 +37,7 @@ constexpr uint32_t kCifarImageChannel = 3;
|
|||
constexpr uint32_t kCifarBlockImageNum = 5;
|
||||
constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
|
||||
|
||||
CifarOp::Builder::Builder() : sampler_(nullptr) {
|
||||
CifarOp::Builder::Builder() : sampler_(nullptr), usage_("") {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -65,23 +66,27 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
|
|||
ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
|
||||
}
|
||||
|
||||
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
|
||||
*ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
|
||||
std::move(schema_), std::move(sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CifarOp::Builder::SanityCheck() {
|
||||
const std::set<std::string> valid = {"test", "train", "all", ""};
|
||||
Path dir(dir_);
|
||||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : "";
|
||||
err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : "";
|
||||
err_msg += valid.find(usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
|
||||
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_works, queue_size, std::move(sampler)),
|
||||
cifar_type_(type),
|
||||
usage_(usage),
|
||||
rows_per_buffer_(rows_per_buf),
|
||||
folder_path_(file_dir),
|
||||
data_schema_(std::move(data_schema)),
|
||||
|
@ -258,21 +263,32 @@ Status CifarOp::ReadCifarBlockDataAsync() {
|
|||
}
|
||||
|
||||
Status CifarOp::ReadCifar10BlockData() {
|
||||
// CIFAR 10 has 6 bin files. data_batch_1.bin ... data_batch_5.bin and 1 test_batch.bin file
|
||||
// each of the file has exactly 10K images and labels and size is 30,730 KB
|
||||
// each image has the dimension of 32 x 32 x 3 = 3072 plus 1 label (label has 10 classes) so each row has 3073 bytes
|
||||
constexpr uint32_t num_cifar10_records = 10000;
|
||||
uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum; // about 2M
|
||||
std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
|
||||
for (auto &file : cifar_files_) {
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
if (!in.is_open()) {
|
||||
std::string err_msg = file + " can not be opened.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
// check the validity of the file path
|
||||
Path file_path(file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
|
||||
std::string file_name = file_path.Basename();
|
||||
|
||||
if (usage_ == "train") {
|
||||
if (file_name.find("data_batch") == std::string::npos) continue;
|
||||
} else if (usage_ == "test") {
|
||||
if (file_name.find("test_batch") == std::string::npos) continue;
|
||||
} else { // get all the files that contain the word batch, aka any cifar 100 files
|
||||
if (file_name.find("batch") == std::string::npos) continue;
|
||||
}
|
||||
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
|
||||
|
||||
for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
|
||||
(void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
|
||||
if (in.fail()) {
|
||||
RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
|
||||
(void)cifar_raw_data_block_->EmplaceBack(image_data);
|
||||
}
|
||||
in.close();
|
||||
|
@ -283,15 +299,21 @@ Status CifarOp::ReadCifar10BlockData() {
|
|||
}
|
||||
|
||||
Status CifarOp::ReadCifar100BlockData() {
|
||||
// CIFAR 100 has 2 bin files. train.bin (60K imgs) 153,700KB and test.bin (30,740KB) (10K imgs)
|
||||
// each img has two labels. Each row then is 32 * 32 *5 + 2 = 3,074 Bytes
|
||||
uint32_t num_cifar100_records = 0; // test:10000, train:50000
|
||||
uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum; // about 2M
|
||||
std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
|
||||
for (auto &file : cifar_files_) {
|
||||
int pos = file.find_last_of('/');
|
||||
if (pos == std::string::npos) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
|
||||
}
|
||||
std::string file_name(file.substr(pos + 1));
|
||||
// check the validity of the file path
|
||||
Path file_path(file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
|
||||
std::string file_name = file_path.Basename();
|
||||
|
||||
// if usage is train/test, get only these 2 files
|
||||
if (usage_ == "train" && file_name.find("train") == std::string::npos) continue;
|
||||
if (usage_ == "test" && file_name.find("test") == std::string::npos) continue;
|
||||
|
||||
if (file_name.find("test") != std::string::npos) {
|
||||
num_cifar100_records = 10000;
|
||||
} else if (file_name.find("train") != std::string::npos) {
|
||||
|
@ -301,15 +323,11 @@ Status CifarOp::ReadCifar100BlockData() {
|
|||
}
|
||||
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
if (!in.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED(file + " can not be opened.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
|
||||
|
||||
for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
|
||||
(void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
|
||||
if (in.fail()) {
|
||||
RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
|
||||
(void)cifar_raw_data_block_->EmplaceBack(image_data);
|
||||
}
|
||||
in.close();
|
||||
|
@ -319,26 +337,20 @@ Status CifarOp::ReadCifar100BlockData() {
|
|||
}
|
||||
|
||||
Status CifarOp::GetCifarFiles() {
|
||||
// Initialize queue to hold the file names
|
||||
const std::string kExtension = ".bin";
|
||||
Path dataset_directory(folder_path_);
|
||||
auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
|
||||
Path dir_path(folder_path_);
|
||||
auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
|
||||
if (dirIt) {
|
||||
while (dirIt->hasNext()) {
|
||||
Path file = dirIt->next();
|
||||
std::string filename = file.toString();
|
||||
if (filename.find(kExtension) != std::string::npos) {
|
||||
cifar_files_.push_back(filename);
|
||||
MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
|
||||
if (file.Extension() == kExtension) {
|
||||
cifar_files_.push_back(file.toString());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::string err_msg = "Unable to open directory " + dataset_directory.toString();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
if (cifar_files_.size() == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
|
||||
RETURN_STATUS_UNEXPECTED("Unable to open directory " + dir_path.toString());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!cifar_files_.empty(), "No .bin files found under " + folder_path_);
|
||||
std::sort(cifar_files_.begin(), cifar_files_.end());
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -378,9 +390,8 @@ Status CifarOp::ParseCifarData() {
|
|||
num_rows_ = cifar_image_label_pairs_.size();
|
||||
if (num_rows_ == 0) {
|
||||
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
||||
std::string err_msg = "There is no valid data matching the dataset API " + api +
|
||||
".Please check file path or dataset API validation first.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
RETURN_STATUS_UNEXPECTED("There is no valid data matching the dataset API " + api +
|
||||
".Please check file path or dataset API validation first.");
|
||||
}
|
||||
cifar_raw_data_block_->Reset();
|
||||
return Status::OK();
|
||||
|
@ -403,46 +414,51 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
|
||||
Status CifarOp::CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
|
||||
std::shared_ptr<CifarOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).SetUsage(usage).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->GetCifarFiles());
|
||||
if (op->cifar_type_ == kCifar10) {
|
||||
constexpr int64_t num_cifar10_records = 10000;
|
||||
for (auto &file : op->cifar_files_) {
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
if (!in.is_open()) {
|
||||
std::string err_msg = file + " can not be opened.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
Path file_path(file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
|
||||
std::string file_name = file_path.Basename();
|
||||
|
||||
if (op->usage_ == "train") {
|
||||
if (file_name.find("data_batch") == std::string::npos) continue;
|
||||
} else if (op->usage_ == "test") {
|
||||
if (file_name.find("test_batch") == std::string::npos) continue;
|
||||
} else { // get all the files that contain the word batch, aka any cifar 100 files
|
||||
if (file_name.find("batch") == std::string::npos) continue;
|
||||
}
|
||||
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
|
||||
*count = *count + num_cifar10_records;
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
int64_t num_cifar100_records = 0;
|
||||
for (auto &file : op->cifar_files_) {
|
||||
size_t pos = file.find_last_of('/');
|
||||
if (pos == std::string::npos) {
|
||||
std::string err_msg = "Invalid cifar100 file path";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
std::string file_name;
|
||||
if (file.size() > 0)
|
||||
file_name = file.substr(pos + 1);
|
||||
else
|
||||
RETURN_STATUS_UNEXPECTED("Invalid string length!");
|
||||
Path file_path(file);
|
||||
std::string file_name = file_path.Basename();
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
|
||||
|
||||
if (op->usage_ == "train" && file_path.Basename().find("train") == std::string::npos) continue;
|
||||
if (op->usage_ == "test" && file_path.Basename().find("test") == std::string::npos) continue;
|
||||
|
||||
if (file_name.find("test") != std::string::npos) {
|
||||
num_cifar100_records = 10000;
|
||||
num_cifar100_records += 10000;
|
||||
} else if (file_name.find("train") != std::string::npos) {
|
||||
num_cifar100_records = 50000;
|
||||
num_cifar100_records += 50000;
|
||||
}
|
||||
std::ifstream in(file, std::ios::binary);
|
||||
if (!in.is_open()) {
|
||||
std::string err_msg = file + " can not be opened.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
|
||||
}
|
||||
*count = num_cifar100_records;
|
||||
return Status::OK();
|
||||
|
|
|
@ -83,15 +83,23 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Setter method
|
||||
// @param const std::string & dir
|
||||
// @return
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetCifarDir(const std::string &dir) {
|
||||
dir_ = dir;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string &usage
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetUsage(const std::string &usage) {
|
||||
usage_ = usage;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string & dir
|
||||
// @return
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetCifarType(const bool cifar10) {
|
||||
if (cifar10) {
|
||||
cifar_type_ = kCifar10;
|
||||
|
@ -112,6 +120,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
std::string dir_;
|
||||
std::string usage_;
|
||||
int32_t num_workers_;
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t op_connect_size_;
|
||||
|
@ -122,13 +131,15 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Constructor
|
||||
// @param CifarType type - Cifar10 or Cifar100
|
||||
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
|
||||
// @param uint32_t numWorks - Num of workers reading images in parallel
|
||||
// @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer
|
||||
// @param std::string - dir directory of cifar dataset
|
||||
// @param uint32_t - queueSize - connector queue size
|
||||
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
|
||||
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
// Destructor.
|
||||
~CifarOp() = default;
|
||||
|
||||
|
@ -153,7 +164,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
|
||||
// @param count output arg that will hold the actual dataset size
|
||||
// @return
|
||||
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count);
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
/// \param[in] p Pointer to the NodePass to be accepted
|
||||
|
@ -224,7 +235,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
std::unique_ptr<DataSchema> data_schema_;
|
||||
int64_t row_cnt_;
|
||||
int64_t buf_cnt_;
|
||||
|
||||
const std::string usage_; // can only be either "train" or "test"
|
||||
WaitPost wp_;
|
||||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
|
@ -32,7 +33,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
|
|||
const int32_t kMnistImageRows = 28;
|
||||
const int32_t kMnistImageCols = 28;
|
||||
|
||||
MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
|
||||
MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -52,22 +53,25 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
|
|||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
*ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
*ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MnistOp::Builder::SanityCheck() {
|
||||
const std::set<std::string> valid = {"test", "train", "all", ""};
|
||||
Path dir(builder_dir_);
|
||||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
|
||||
err_msg += valid.find(builder_usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, queue_size, std::move(sampler)),
|
||||
usage_(usage),
|
||||
buf_cnt_(0),
|
||||
row_cnt_(0),
|
||||
folder_path_(folder_path),
|
||||
|
@ -226,9 +230,7 @@ Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
|
|||
Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) {
|
||||
uint32_t res = 0;
|
||||
reader->read(reinterpret_cast<char *>(&res), 4);
|
||||
if (reader->fail()) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(), "Failed to read 4 bytes from file");
|
||||
*result = SwapEndian(res);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -239,15 +241,12 @@ uint32_t MnistOp::SwapEndian(uint32_t val) const {
|
|||
}
|
||||
|
||||
Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
|
||||
if (image_reader->is_open() == false) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(), "Cannot open mnist image file: " + file_name);
|
||||
int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
|
||||
(void)image_reader->seekg(0, std::ios::beg);
|
||||
// The first 16 bytes of the image file are type, number, row and column
|
||||
if (image_len < 16) {
|
||||
RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Mnist file is corrupted.");
|
||||
|
||||
uint32_t magic_number;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
|
||||
|
@ -260,35 +259,25 @@ Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_re
|
|||
uint32_t cols;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols));
|
||||
// The image size of the Mnist dataset is fixed at [28,28]
|
||||
if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) {
|
||||
RETURN_STATUS_UNEXPECTED("Wrong shape of image.");
|
||||
}
|
||||
if ((image_len - 16) != num_items * rows * cols) {
|
||||
RETURN_STATUS_UNEXPECTED("Wrong number of image.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((rows == kMnistImageRows) && (cols == kMnistImageCols), "Wrong shape of image.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((image_len - 16) == num_items * rows * cols, "Wrong number of image.");
|
||||
*num_images = num_items;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
|
||||
if (label_reader->is_open() == false) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), "Cannot open mnist label file: " + file_name);
|
||||
int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
|
||||
(void)label_reader->seekg(0, std::ios::beg);
|
||||
// The first 8 bytes of the image file are type and number
|
||||
if (label_len < 8) {
|
||||
RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Mnist file is corrupted.");
|
||||
uint32_t magic_number;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
|
||||
"This is not the mnist label file: " + file_name);
|
||||
uint32_t num_items;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
|
||||
if ((label_len - 8) != num_items) {
|
||||
RETURN_STATUS_UNEXPECTED("Wrong number of labels!");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Wrong number of labels!");
|
||||
*num_labels = num_items;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -330,6 +319,9 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la
|
|||
}
|
||||
|
||||
Status MnistOp::ParseMnistData() {
|
||||
// MNIST contains 4 files, idx3 are image files, idx 1 are labels
|
||||
// training files contain 60K examples and testing files contain 10K examples
|
||||
// t10k-images-idx3-ubyte t10k-labels-idx1-ubyte train-images-idx3-ubyte train-labels-idx1-ubyte
|
||||
for (size_t i = 0; i < image_names_.size(); ++i) {
|
||||
std::ifstream image_reader, label_reader;
|
||||
image_reader.open(image_names_[i], std::ios::binary);
|
||||
|
@ -354,18 +346,22 @@ Status MnistOp::ParseMnistData() {
|
|||
Status MnistOp::WalkAllFiles() {
|
||||
const std::string kImageExtension = "idx3-ubyte";
|
||||
const std::string kLabelExtension = "idx1-ubyte";
|
||||
const std::string train_prefix = "train";
|
||||
const std::string test_prefix = "t10k";
|
||||
|
||||
Path dir(folder_path_);
|
||||
auto dir_it = Path::DirIterator::OpenDirectory(&dir);
|
||||
std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
|
||||
if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);
|
||||
if (dir_it != nullptr) {
|
||||
while (dir_it->hasNext()) {
|
||||
Path file = dir_it->next();
|
||||
std::string filename = file.toString();
|
||||
if (filename.find(kImageExtension) != std::string::npos) {
|
||||
image_names_.push_back(filename);
|
||||
std::string filename = file.Basename();
|
||||
if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) {
|
||||
image_names_.push_back(file.toString());
|
||||
MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
|
||||
} else if (filename.find(kLabelExtension) != std::string::npos) {
|
||||
label_names_.push_back(filename);
|
||||
} else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) {
|
||||
label_names_.push_back(file.toString());
|
||||
MS_LOG(INFO) << "Mnist Operator found label file at " << filename << ".";
|
||||
}
|
||||
}
|
||||
|
@ -376,9 +372,7 @@ Status MnistOp::WalkAllFiles() {
|
|||
std::sort(image_names_.begin(), image_names_.end());
|
||||
std::sort(label_names_.begin(), label_names_.end());
|
||||
|
||||
if (image_names_.size() != label_names_.size()) {
|
||||
RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(), "num of idx3 files != num of idx1 files");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -397,11 +391,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
|
||||
Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
|
||||
std::shared_ptr<MnistOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op));
|
||||
|
||||
RETURN_IF_NOT_OK(op->WalkAllFiles());
|
||||
|
||||
|
|
|
@ -47,8 +47,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
class Builder {
|
||||
public:
|
||||
// Constructor for Builder class of MnistOp
|
||||
// @param uint32_t numWrks - number of parallel workers
|
||||
// @param dir - directory folder got ImageNetFolder
|
||||
Builder();
|
||||
|
||||
// Destructor.
|
||||
|
@ -87,13 +85,20 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string & dir
|
||||
// @param const std::string &dir
|
||||
// @return
|
||||
Builder &SetDir(const std::string &dir) {
|
||||
builder_dir_ = dir;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string &usage
|
||||
// @return
|
||||
Builder &SetUsage(const std::string &usage) {
|
||||
builder_usage_ = usage;
|
||||
return *this;
|
||||
}
|
||||
// Check validity of input args
|
||||
// @return - The error code return
|
||||
Status SanityCheck();
|
||||
|
@ -105,6 +110,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
std::string builder_dir_;
|
||||
std::string builder_usage_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
int32_t builder_op_connector_size_;
|
||||
|
@ -113,14 +119,15 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
};
|
||||
|
||||
// Constructor
|
||||
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
|
||||
// @param int32_t num_workers - number of workers reading images in parallel
|
||||
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
|
||||
// @param std::string folder_path - dir directory of mnist
|
||||
// @param int32_t queue_size - connector queue size
|
||||
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
|
||||
MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
|
||||
// Destructor.
|
||||
~MnistOp() = default;
|
||||
|
@ -150,7 +157,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param dir path to the MNIST directory
|
||||
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
|
||||
// @return
|
||||
static Status CountTotalRows(const std::string &dir, int64_t *count);
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
/// \param[in] p Pointer to the NodePass to be accepted
|
||||
|
@ -241,6 +248,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
WaitPost wp_;
|
||||
std::string folder_path_; // directory of image folder
|
||||
int32_t rows_per_buffer_;
|
||||
const std::string usage_; // can only be either "train" or "test"
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::vector<MnistLabelPair> image_label_pairs_;
|
||||
std::vector<std::string> image_names_;
|
||||
|
|
|
@ -18,14 +18,15 @@
|
|||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "./tinyxml2.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
using tinyxml2::XMLDocument;
|
||||
using tinyxml2::XMLElement;
|
||||
|
@ -81,7 +82,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
|
|||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
}
|
||||
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
|
||||
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_usage_, builder_dir_, builder_labels_to_read_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
|
||||
builder_decode_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
|
@ -103,7 +104,7 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
|
|||
row_cnt_(0),
|
||||
buf_cnt_(0),
|
||||
task_type_(task_type),
|
||||
task_mode_(task_mode),
|
||||
usage_(task_mode),
|
||||
folder_path_(folder_path),
|
||||
class_index_(class_index),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
|
@ -251,10 +252,9 @@ Status VOCOp::WorkerEntry(int32_t worker_id) {
|
|||
Status VOCOp::ParseImageIds() {
|
||||
std::string image_sets_file;
|
||||
if (task_type_ == TaskType::Segmentation) {
|
||||
image_sets_file =
|
||||
folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension);
|
||||
image_sets_file = folder_path_ + std::string(kImageSetsSegmentation) + usage_ + std::string(kImageSetsExtension);
|
||||
} else if (task_type_ == TaskType::Detection) {
|
||||
image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension);
|
||||
image_sets_file = folder_path_ + std::string(kImageSetsMain) + usage_ + std::string(kImageSetsExtension);
|
||||
}
|
||||
std::ifstream in_file;
|
||||
in_file.open(image_sets_file);
|
||||
|
@ -431,13 +431,13 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
|
|||
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
|
||||
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
*count = static_cast<int64_t>(op->image_ids_.size());
|
||||
} else if (task_type == "Segmentation") {
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
*count = static_cast<int64_t>(op->image_ids_.size());
|
||||
}
|
||||
|
@ -458,7 +458,7 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
|
|||
} else {
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
|
||||
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
for (const auto label : op->label_index_) {
|
||||
|
|
|
@ -73,7 +73,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Setter method.
|
||||
// @param const std::string & task_type
|
||||
// @param const std::string &task_type
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetTask(const std::string &task_type) {
|
||||
if (task_type == "Segmentation") {
|
||||
|
@ -85,10 +85,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Setter method.
|
||||
// @param const std::string & task_mode
|
||||
// @param const std::string &usage
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetMode(const std::string &task_mode) {
|
||||
builder_task_mode_ = task_mode;
|
||||
Builder &SetUsage(const std::string &usage) {
|
||||
builder_usage_ = usage;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
bool builder_decode_;
|
||||
std::string builder_dir_;
|
||||
TaskType builder_task_type_;
|
||||
std::string builder_task_mode_;
|
||||
std::string builder_usage_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
|
@ -279,7 +279,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
int64_t buf_cnt_;
|
||||
std::string folder_path_;
|
||||
TaskType task_type_;
|
||||
std::string task_mode_;
|
||||
std::string usage_;
|
||||
int32_t rows_per_buffer_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
|
|
|
@ -111,34 +111,36 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
|
|||
|
||||
/// \brief Function to create a CelebADataset
|
||||
/// \notes The generated dataset has two columns ['image', 'attr'].
|
||||
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
|
||||
/// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
|
||||
/// \param[in] usage One of "all", "train", "valid" or "test".
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
|
||||
const std::set<std::string> &extensions = {});
|
||||
|
||||
/// \brief Function to create a Cifar10 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \notes The generated dataset has two columns ["image", "label"]
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] usage of CIFAR10, can be "train", "test" or "all"
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage = std::string(),
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
|
||||
|
||||
/// \brief Function to create a Cifar100 Dataset
|
||||
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
|
||||
/// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] usage of CIFAR100, can be "train", "test" or "all"
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage = std::string(),
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
|
||||
|
||||
/// \brief Function to create a CLUEDataset
|
||||
|
@ -212,7 +214,7 @@ std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, c
|
|||
/// \brief Function to create an ImageFolderDataset
|
||||
/// \notes A source dataset that reads images from a tree of directories
|
||||
/// All images within one folder have the same label
|
||||
/// The generated dataset has two columns ['image', 'label']
|
||||
/// The generated dataset has two columns ["image", "label"]
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] decode A flag to decode in ImageFolder
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
|
@ -227,7 +229,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to create a ManifestDataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \notes The generated dataset has two columns ["image", "label"]
|
||||
/// \param[in] dataset_file The dataset file to be read
|
||||
/// \param[in] usage Need "train", "eval" or "inference" data (default="train")
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
|
@ -243,12 +245,13 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
|
|||
#endif
|
||||
|
||||
/// \brief Function to create a MnistDataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \notes The generated dataset has two columns ["image", "label"]
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] usage of MNIST, can be "train", "test" or "all"
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \return Shared pointer to the current MnistDataset
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = std::string(),
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
|
||||
|
||||
/// \brief Function to create a ConcatDataset
|
||||
|
@ -404,14 +407,14 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase
|
|||
/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
|
||||
/// \param[in] mode Set the data list txt file to be readed
|
||||
/// \param[in] usage The type of data list text file to be read
|
||||
/// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task
|
||||
/// \param[in] decode Decode the images after reading
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
|
||||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
|
||||
const std::string &mode = "train",
|
||||
const std::string &usage = "train",
|
||||
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
|
||||
#endif
|
||||
|
@ -702,9 +705,8 @@ class AlbumDataset : public Dataset {
|
|||
class CelebADataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
||||
const std::set<std::string> &extensions);
|
||||
CelebADataset(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const bool &decode, const std::set<std::string> &extensions);
|
||||
|
||||
/// \brief Destructor
|
||||
~CelebADataset() = default;
|
||||
|
@ -719,7 +721,7 @@ class CelebADataset : public Dataset {
|
|||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string dataset_type_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::set<std::string> extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
|
@ -730,7 +732,7 @@ class CelebADataset : public Dataset {
|
|||
class Cifar10Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
|
||||
Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Dataset() = default;
|
||||
|
@ -745,13 +747,14 @@ class Cifar10Dataset : public Dataset {
|
|||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class Cifar100Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
|
||||
Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar100Dataset() = default;
|
||||
|
@ -766,6 +769,7 @@ class Cifar100Dataset : public Dataset {
|
|||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
|
@ -831,7 +835,7 @@ class CocoDataset : public Dataset {
|
|||
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
|
||||
|
||||
/// \brief Base class of CSV Record
|
||||
struct CsvBase {
|
||||
class CsvBase {
|
||||
public:
|
||||
CsvBase() = default;
|
||||
explicit CsvBase(CsvType t) : type(t) {}
|
||||
|
@ -936,7 +940,7 @@ class ManifestDataset : public Dataset {
|
|||
class MnistDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
|
||||
MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~MnistDataset() = default;
|
||||
|
@ -951,6 +955,7 @@ class MnistDataset : public Dataset {
|
|||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
|
@ -1087,7 +1092,7 @@ class TFRecordDataset : public Dataset {
|
|||
class VOCDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
|
@ -1110,7 +1115,7 @@ class VOCDataset : public Dataset {
|
|||
const std::string kColumnTruncate = "truncate";
|
||||
std::string dataset_dir_;
|
||||
std::string task_;
|
||||
std::string mode_;
|
||||
std::string usage_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
|
|
|
@ -132,6 +132,12 @@ def check_valid_detype(type_):
|
|||
return True
|
||||
|
||||
|
||||
def check_valid_str(value, valid_strings, arg_name=""):
|
||||
type_check(value, (str,), arg_name)
|
||||
if value not in valid_strings:
|
||||
raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))
|
||||
|
||||
|
||||
def check_columns(columns, name):
|
||||
"""
|
||||
Validate strings in column_names.
|
||||
|
|
|
@ -2877,6 +2877,9 @@ class MnistDataset(MappableDataset):
|
|||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 60,000
|
||||
train samples, "test" will read from 10,000 test samples, "all" will read from all 70,000 samples.
|
||||
(default=None, all samples)
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
|
@ -2906,11 +2909,12 @@ class MnistDataset(MappableDataset):
|
|||
"""
|
||||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = usage
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
self.shuffle_level = shuffle
|
||||
|
@ -2920,6 +2924,7 @@ class MnistDataset(MappableDataset):
|
|||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_dir"] = self.dataset_dir
|
||||
args["usage"] = self.usage
|
||||
args["num_samples"] = self.num_samples
|
||||
args["shuffle"] = self.shuffle_level
|
||||
args["sampler"] = self.sampler
|
||||
|
@ -2935,7 +2940,7 @@ class MnistDataset(MappableDataset):
|
|||
Number, number of batches.
|
||||
"""
|
||||
if self.dataset_size is None:
|
||||
num_rows = MnistOp.get_num_rows(self.dataset_dir)
|
||||
num_rows = MnistOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage)
|
||||
self.dataset_size = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
|
||||
|
@ -3913,6 +3918,9 @@ class Cifar10Dataset(MappableDataset):
|
|||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 50,000
|
||||
train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
|
||||
(default=None, all samples)
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
|
@ -3946,11 +3954,12 @@ class Cifar10Dataset(MappableDataset):
|
|||
"""
|
||||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = usage
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
self.num_shards = num_shards
|
||||
|
@ -3960,6 +3969,7 @@ class Cifar10Dataset(MappableDataset):
|
|||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_dir"] = self.dataset_dir
|
||||
args["usage"] = self.usage
|
||||
args["num_samples"] = self.num_samples
|
||||
args["sampler"] = self.sampler
|
||||
args["num_shards"] = self.num_shards
|
||||
|
@ -3975,7 +3985,7 @@ class Cifar10Dataset(MappableDataset):
|
|||
Number, number of batches.
|
||||
"""
|
||||
if self.dataset_size is None:
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, True)
|
||||
self.dataset_size = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -4051,6 +4061,9 @@ class Cifar100Dataset(MappableDataset):
|
|||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 50,000
|
||||
train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
|
||||
(default=None, all samples)
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
|
@ -4082,11 +4095,12 @@ class Cifar100Dataset(MappableDataset):
|
|||
"""
|
||||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = usage
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
self.num_shards = num_shards
|
||||
|
@ -4096,6 +4110,7 @@ class Cifar100Dataset(MappableDataset):
|
|||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_dir"] = self.dataset_dir
|
||||
args["usage"] = self.usage
|
||||
args["num_samples"] = self.num_samples
|
||||
args["sampler"] = self.sampler
|
||||
args["num_shards"] = self.num_shards
|
||||
|
@ -4111,7 +4126,7 @@ class Cifar100Dataset(MappableDataset):
|
|||
Number, number of batches.
|
||||
"""
|
||||
if self.dataset_size is None:
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, False)
|
||||
self.dataset_size = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -4467,7 +4482,7 @@ class VOCDataset(MappableDataset):
|
|||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection"
|
||||
(default="Segmentation").
|
||||
mode (str): Set the data list txt file to be readed (default="train").
|
||||
usage (str): The type of data list text file to be read (default="train").
|
||||
class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in
|
||||
"Detection" task (default=None, the folder names will be sorted alphabetically and each
|
||||
class will be given a unique index starting from 0).
|
||||
|
@ -4502,24 +4517,24 @@ class VOCDataset(MappableDataset):
|
|||
>>> import mindspore.dataset as ds
|
||||
>>> dataset_dir = "/path/to/voc_dataset_directory"
|
||||
>>> # 1) read VOC data for segmenatation train
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", usage="train")
|
||||
>>> # 2) read VOC data for detection train
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train")
|
||||
>>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", num_parallel_workers=8)
|
||||
>>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", decode=True, shuffle=False)
|
||||
>>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
|
||||
>>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
|
||||
"""
|
||||
|
||||
@check_vocdataset
|
||||
def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
|
||||
def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.task = task
|
||||
self.mode = mode
|
||||
self.usage = usage
|
||||
self.class_indexing = class_indexing
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
|
@ -4532,7 +4547,7 @@ class VOCDataset(MappableDataset):
|
|||
args = super().get_args()
|
||||
args["dataset_dir"] = self.dataset_dir
|
||||
args["task"] = self.task
|
||||
args["mode"] = self.mode
|
||||
args["usage"] = self.usage
|
||||
args["class_indexing"] = self.class_indexing
|
||||
args["num_samples"] = self.num_samples
|
||||
args["sampler"] = self.sampler
|
||||
|
@ -4560,7 +4575,7 @@ class VOCDataset(MappableDataset):
|
|||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
|
||||
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.usage, class_indexing, num_samples)
|
||||
self.dataset_size = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -4584,7 +4599,7 @@ class VOCDataset(MappableDataset):
|
|||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
|
||||
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.usage, class_indexing)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -4824,7 +4839,7 @@ class CelebADataset(MappableDataset):
|
|||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
|
||||
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
|
||||
dataset_type (str): one of 'all', 'train', 'valid' or 'test'.
|
||||
usage (str): one of 'all', 'train', 'valid' or 'test'.
|
||||
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
|
||||
decode (bool, optional): decode the images after reading (default=False).
|
||||
extensions (list[str], optional): List of file extensions to be
|
||||
|
@ -4838,8 +4853,8 @@ class CelebADataset(MappableDataset):
|
|||
"""
|
||||
|
||||
@check_celebadataset
|
||||
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all',
|
||||
sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
|
||||
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
|
||||
extensions=None, num_samples=None, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
|
@ -4847,7 +4862,7 @@ class CelebADataset(MappableDataset):
|
|||
self.decode = decode
|
||||
self.extensions = extensions
|
||||
self.num_samples = num_samples
|
||||
self.dataset_type = dataset_type
|
||||
self.usage = usage
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle_level = shuffle
|
||||
|
@ -4860,7 +4875,7 @@ class CelebADataset(MappableDataset):
|
|||
args["decode"] = self.decode
|
||||
args["extensions"] = self.extensions
|
||||
args["num_samples"] = self.num_samples
|
||||
args["dataset_type"] = self.dataset_type
|
||||
args["usage"] = self.usage
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
|
|
@ -273,7 +273,7 @@ def create_node(node):
|
|||
|
||||
elif dataset_op == 'MnistDataset':
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'MindDataset':
|
||||
|
@ -296,12 +296,12 @@ def create_node(node):
|
|||
|
||||
elif dataset_op == 'Cifar10Dataset':
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'Cifar100Dataset':
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'VOCDataset':
|
||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.dataset.callback import DSCallback
|
|||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
|
||||
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
|
||||
check_columns, check_pos_int32
|
||||
check_columns, check_pos_int32, check_valid_str
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
@ -74,6 +74,10 @@ def check_mnist_cifar_dataset(method):
|
|||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test", "all"], "usage")
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
|
@ -154,15 +158,15 @@ def check_vocdataset(method):
|
|||
task = param_dict.get('task')
|
||||
type_check(task, (str,), "task")
|
||||
|
||||
mode = param_dict.get('mode')
|
||||
type_check(mode, (str,), "mode")
|
||||
usage = param_dict.get('usage')
|
||||
type_check(usage, (str,), "usage")
|
||||
|
||||
if task == "Segmentation":
|
||||
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
|
||||
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
|
||||
if param_dict.get('class_indexing') is not None:
|
||||
raise ValueError("class_indexing is invalid in Segmentation task")
|
||||
elif task == "Detection":
|
||||
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
|
||||
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
|
||||
else:
|
||||
raise ValueError("Invalid task : " + task)
|
||||
|
||||
|
@ -235,9 +239,9 @@ def check_celebadataset(method):
|
|||
validate_dataset_param_value(nreq_param_list, param_dict, list)
|
||||
validate_dataset_param_value(nreq_param_str, param_dict, str)
|
||||
|
||||
dataset_type = param_dict.get('dataset_type')
|
||||
if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
|
||||
raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
|
||||
raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ SET(DE_UT_SRCS
|
|||
common/cvop_common.cc
|
||||
common/bboxop_common.cc
|
||||
auto_contrast_op_test.cc
|
||||
album_op_test.cc
|
||||
album_op_test.cc
|
||||
batch_op_test.cc
|
||||
bit_functions_test.cc
|
||||
storage_container_test.cc
|
||||
|
@ -62,8 +62,8 @@ SET(DE_UT_SRCS
|
|||
rescale_op_test.cc
|
||||
resize_op_test.cc
|
||||
resize_with_bbox_op_test.cc
|
||||
rgba_to_bgr_op_test.cc
|
||||
rgba_to_rgb_op_test.cc
|
||||
rgba_to_bgr_op_test.cc
|
||||
rgba_to_rgb_op_test.cc
|
||||
schema_test.cc
|
||||
skip_op_test.cc
|
||||
shuffle_op_test.cc
|
||||
|
|
|
@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
|
@ -45,10 +45,10 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
@ -62,7 +62,7 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
|
|||
|
||||
// Create a Cifar100 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
|
@ -96,7 +96,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1.";
|
||||
|
||||
// Create a Cifar100 Dataset
|
||||
std::shared_ptr<Dataset> ds = Cifar100("", RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar100("", std::string(), RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1.";
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
std::shared_ptr<Dataset> ds = Cifar10("", RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10("", std::string(), RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetWithNullSampler) {
|
|||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, nullptr);
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), nullptr);
|
||||
// Expect failure: sampler can not be nullptr
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithNullSampler) {
|
|||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, nullptr);
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), nullptr);
|
||||
// Expect failure: sampler can not be nullptr
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithWrongSampler) {
|
|||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, -10));
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, -10));
|
||||
// Expect failure: sampler is not construnced correctly
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestIteratorEmptyColumn) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn.";
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 5));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Rename operation on ds
|
||||
|
@ -64,7 +64,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn.";
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -103,7 +103,7 @@ TEST_F(MindDataTestPipeline, TestIteratorReOrder) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder.";
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, SequentialSampler(false, 4));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), SequentialSampler(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Take operation on ds
|
||||
|
@ -160,9 +160,8 @@ TEST_F(MindDataTestPipeline, TestIteratorTwoColumns) {
|
|||
// Iterate the dataset and get each row
|
||||
std::vector<std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}),
|
||||
TensorShape({173673}), TensorShape({1, 4}),
|
||||
TensorShape({147025}), TensorShape({1, 4}),
|
||||
std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}), TensorShape({173673}),
|
||||
TensorShape({1, 4}), TensorShape({147025}), TensorShape({1, 4}),
|
||||
TensorShape({211653}), TensorShape({1, 4})};
|
||||
|
||||
uint64_t i = 0;
|
||||
|
@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn.";
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Pass wrong column name
|
||||
|
|
|
@ -40,7 +40,7 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
@ -82,7 +82,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess1) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -118,13 +118,12 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
std::map<std::string, std::pair<mindspore::dataset::TensorShape, std::shared_ptr<Tensor>>> pad_info;
|
||||
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3},
|
||||
&BucketBatchTestFunction, pad_info, true, true);
|
||||
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3}, &BucketBatchTestFunction, pad_info, true, true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
|
@ -157,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail1) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -172,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail2) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail3) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -202,7 +201,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail4) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -217,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail5) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -232,7 +231,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail6) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3});
|
||||
|
@ -246,7 +245,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail7) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a BucketBatchByLength operation on ds
|
||||
|
@ -313,7 +312,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) {
|
|||
// Create a Cifar10 Dataset
|
||||
// Column names: {"image", "label"}
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
|
@ -365,7 +364,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess2) {
|
|||
// Create a Cifar10 Dataset
|
||||
// Column names: {"image", "label"}
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
|
@ -704,11 +703,11 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRepeatDefault) {
|
||||
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault.";
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatDefault.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
@ -723,21 +722,21 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
|
|||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr <Iterator> iter = ds->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// iterate over the dataset and get each row
|
||||
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
uint64_t i = 0;
|
||||
while (row.size()!= 0) {
|
||||
while (row.size() != 0) {
|
||||
// manually stop
|
||||
if (i == 100) {
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
|
@ -747,11 +746,11 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRepeatOne) {
|
||||
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne.";
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatOne.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
@ -766,17 +765,17 @@ TEST_F(MindDataTestPipeline, TestRepeatOne) {
|
|||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr <Iterator> iter = ds->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// iterate over the dataset and get each row
|
||||
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
uint64_t i = 0;
|
||||
while (row.size()!= 0) {
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
|
@ -1013,7 +1012,7 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 20));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
@ -1060,7 +1059,6 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestZipFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail.";
|
||||
// We expect this test to fail because we are the both datasets we are zipping have "image" and "label" columns
|
||||
|
@ -1128,7 +1126,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
|
|||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
|
|
|
@ -43,10 +43,11 @@ TEST_F(MindDataTestPipeline, TestCelebADataset) {
|
|||
|
||||
// Check if CelebAOp read correct images/attr
|
||||
std::string expect_file[] = {"1.JPEG", "2.jpg"};
|
||||
std::vector<std::vector<uint32_t>> expect_attr_vector =
|
||||
{{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0,
|
||||
1, 0, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
|
||||
1, 0, 0, 0, 0, 0, 0, 0, 1}};
|
||||
std::vector<std::vector<uint32_t>> expect_attr_vector = {
|
||||
{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1},
|
||||
{0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto image = row["image"];
|
||||
|
@ -132,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) {
|
|||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir.";
|
||||
|
||||
// Create a Mnist Dataset
|
||||
std::shared_ptr<Dataset> ds = Mnist("", RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Mnist("", std::string(), RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
|
@ -141,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithNullSampler) {
|
|||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, nullptr);
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), nullptr);
|
||||
// Expect failure: sampler can not be nullptr
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
|
|||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
int number_of_classes = 10;
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
|
@ -38,7 +38,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
|
|||
EXPECT_NE(hwc_to_chw, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({hwc_to_chw},{"image"});
|
||||
ds = ds->Map({hwc_to_chw}, {"image"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -51,10 +51,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
|
|||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op =
|
||||
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
|
||||
EXPECT_NE(cutmix_batch_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
|
@ -77,10 +78,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
|
|||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Label shape: " << label->shape();
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1]
|
||||
&& 32 == image->shape()[2] && 32 == image->shape()[3], true);
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] &&
|
||||
32 == image->shape()[2] && 32 == image->shape()[3],
|
||||
true);
|
||||
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
|
||||
number_of_classes == label->shape()[1], true);
|
||||
number_of_classes == label->shape()[1],
|
||||
true);
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
|
@ -95,7 +98,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
|
|||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
int number_of_classes = 10;
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -108,7 +111,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
|
|||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC);
|
||||
|
@ -134,10 +137,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
|
|||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Label shape: " << label->shape();
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1]
|
||||
&& 32 == image->shape()[2] && 3 == image->shape()[3], true);
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] &&
|
||||
32 == image->shape()[2] && 3 == image->shape()[3],
|
||||
true);
|
||||
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
|
||||
number_of_classes == label->shape()[1], true);
|
||||
number_of_classes == label->shape()[1],
|
||||
true);
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
|
@ -151,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
|
|||
// Must fail because alpha can't be negative
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -164,10 +169,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
|
|||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5);
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op =
|
||||
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
}
|
||||
|
||||
|
@ -175,7 +181,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
|
|||
// Must fail because prob can't be negative
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -188,20 +194,19 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
|
|||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
|
||||
1, -0.5);
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op =
|
||||
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
|
||||
// Must fail because alpha can't be zero
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -214,11 +219,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
|
|||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
|
||||
0.0, 0.5);
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op =
|
||||
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
}
|
||||
|
||||
|
@ -371,7 +376,7 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) {
|
|||
TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -395,7 +400,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
|
|||
// This should fail because alpha can't be zero
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -418,7 +423,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
|
|||
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -467,7 +472,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
|
|||
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
|
@ -871,8 +876,7 @@ TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) {
|
|||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> posterize =
|
||||
vision::RandomPosterize({1, 4});
|
||||
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize({1, 4});
|
||||
EXPECT_NE(posterize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
|
@ -1114,7 +1118,7 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) {
|
|||
TEST_F(MindDataTestPipeline, TestUniformAugWithOps) {
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, "", RandomSampler(false, 20));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
|
|
@ -42,9 +42,13 @@ std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, i
|
|||
bool decode = false, const std::string &dataset_type="all") {
|
||||
std::shared_ptr<CelebAOp> so;
|
||||
CelebAOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
|
||||
.SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
|
||||
.SetDatasetType(dataset_type).Build(&so);
|
||||
Status rc = builder.SetNumWorkers(num_workers)
|
||||
.SetCelebADir(dir)
|
||||
.SetRowsPerBuffer(rows_per_buffer)
|
||||
.SetOpConnectorSize(queue_size)
|
||||
.SetSampler(std::move(sampler))
|
||||
.SetDecode(decode)
|
||||
.SetUsage(dataset_type).Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
|
|
|
@ -63,9 +63,7 @@ TEST_F(MindDataTestVOCOp, TestVOCDetection) {
|
|||
std::string task_mode("train");
|
||||
std::shared_ptr<VOCOp> my_voc_op;
|
||||
VOCOp::Builder builder;
|
||||
Status rc = builder.SetDir(dataset_path)
|
||||
.SetTask(task_type)
|
||||
.SetMode(task_mode)
|
||||
Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
|
||||
.Build(&my_voc_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -116,9 +114,7 @@ TEST_F(MindDataTestVOCOp, TestVOCSegmentation) {
|
|||
std::string task_mode("train");
|
||||
std::shared_ptr<VOCOp> my_voc_op;
|
||||
VOCOp::Builder builder;
|
||||
Status rc = builder.SetDir(dataset_path)
|
||||
.SetTask(task_type)
|
||||
.SetMode(task_mode)
|
||||
Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
|
||||
.Build(&my_voc_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -173,9 +169,8 @@ TEST_F(MindDataTestVOCOp, TestVOCClassIndex) {
|
|||
class_index["train"] = 5;
|
||||
std::shared_ptr<VOCOp> my_voc_op;
|
||||
VOCOp::Builder builder;
|
||||
Status rc = builder.SetDir(dataset_path)
|
||||
.SetTask(task_type)
|
||||
.SetMode(task_mode)
|
||||
Status rc =
|
||||
builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
|
||||
.SetClassIndex(class_index)
|
||||
.Build(&my_voc_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
|
|
@ -42,8 +42,8 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False):
|
|||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# Ratio is set to 1 to apply rotation on all bounding boxes.
|
||||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
|
||||
|
@ -81,8 +81,8 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False):
|
|||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes.
|
||||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
|
||||
|
@ -120,8 +120,8 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
|
|||
original_seed = config_get_set_seed(1)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
|
||||
|
||||
|
@ -188,8 +188,8 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False):
|
|||
original_seed = config_get_set_seed(1)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
|
||||
|
||||
|
@ -232,7 +232,7 @@ def test_bounding_box_augment_invalid_ratio_c():
|
|||
"""
|
||||
logger.info("test_bounding_box_augment_invalid_ratio_c")
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
# ratio range is from 0 - 1
|
||||
|
@ -256,13 +256,13 @@ def test_bounding_box_augment_invalid_bounds_c():
|
|||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
|
||||
1)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ DATA_DIR = "../data/dataset/testCelebAData/"
|
|||
|
||||
|
||||
def test_celeba_dataset_label():
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
|
||||
expect_labels = [
|
||||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
|
||||
0, 0, 1],
|
||||
|
@ -85,11 +85,13 @@ def test_celeba_dataset_distribute():
|
|||
count = count + 1
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_celeba_get_dataset_size():
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
|
||||
size = data.get_dataset_size()
|
||||
assert size == 2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_celeba_dataset_label()
|
||||
test_celeba_dataset_op()
|
||||
|
|
|
@ -392,6 +392,59 @@ def test_cifar100_visualize(plot=False):
|
|||
visualize_dataset(image_list, label_list)
|
||||
|
||||
|
||||
def test_cifar_usage():
|
||||
"""
|
||||
test usage of cifar
|
||||
"""
|
||||
logger.info("Test Cifar100Dataset usage flag")
|
||||
|
||||
# flag, if True, test cifar10 else test cifar100
|
||||
def test_config(usage, flag=True, cifar_path=None):
|
||||
if cifar_path is None:
|
||||
cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
|
||||
try:
|
||||
data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
# test the usage of CIFAR100
|
||||
assert test_config("train") == 10000
|
||||
assert test_config("all") == 10000
|
||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
||||
assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
|
||||
assert "no valid data matching the dataset API Cifar10Dataset" in test_config("test")
|
||||
|
||||
# test the usage of CIFAR10
|
||||
assert test_config("test", False) == 10000
|
||||
assert test_config("all", False) == 10000
|
||||
assert "no valid data matching the dataset API Cifar100Dataset" in test_config("train", False)
|
||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)
|
||||
|
||||
# change this directory to the folder that contains all cifar10 files
|
||||
all_cifar10 = None
|
||||
if all_cifar10 is not None:
|
||||
assert test_config("train", True, all_cifar10) == 50000
|
||||
assert test_config("test", True, all_cifar10) == 10000
|
||||
assert test_config("all", True, all_cifar10) == 60000
|
||||
assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
|
||||
assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
|
||||
assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000
|
||||
|
||||
# change this directory to the folder that contains all cifar100 files
|
||||
all_cifar100 = None
|
||||
if all_cifar100 is not None:
|
||||
assert test_config("train", False, all_cifar100) == 50000
|
||||
assert test_config("test", False, all_cifar100) == 10000
|
||||
assert test_config("all", False, all_cifar100) == 60000
|
||||
assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
|
||||
assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
|
||||
assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cifar10_content_check()
|
||||
test_cifar10_basic()
|
||||
|
@ -405,3 +458,5 @@ if __name__ == '__main__':
|
|||
test_cifar100_pk_sampler()
|
||||
test_cifar100_exception()
|
||||
test_cifar100_visualize(plot=False)
|
||||
|
||||
test_cifar_usage()
|
||||
|
|
|
@ -58,6 +58,14 @@ def test_mnist_dataset_size():
|
|||
ds_total = ds.MnistDataset(MNIST_DATA_DIR)
|
||||
assert ds_total.get_dataset_size() == 10000
|
||||
|
||||
# test get dataset_size with the usage arg
|
||||
test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
|
||||
assert test_size == 10000
|
||||
train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 0
|
||||
all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
|
||||
assert all_size == 10000
|
||||
|
||||
ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
|
||||
assert ds_shard_1_0.get_dataset_size() == 10000
|
||||
|
||||
|
@ -86,6 +94,14 @@ def test_cifar10_dataset_size():
|
|||
ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
|
||||
assert ds_total.get_dataset_size() == 10000
|
||||
|
||||
# test get_dataset_size with usage flag
|
||||
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 10000
|
||||
test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
|
||||
assert test_size == 0
|
||||
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
|
||||
assert all_size == 10000
|
||||
|
||||
ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
|
||||
assert ds_shard_1_0.get_dataset_size() == 10000
|
||||
|
||||
|
@ -103,6 +119,14 @@ def test_cifar100_dataset_size():
|
|||
ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
|
||||
assert ds_total.get_dataset_size() == 10000
|
||||
|
||||
# test get_dataset_size with usage flag
|
||||
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 0
|
||||
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
|
||||
assert test_size == 10000
|
||||
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
|
||||
assert all_size == 10000
|
||||
|
||||
ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
|
||||
assert ds_shard_1_0.get_dataset_size() == 10000
|
||||
|
||||
|
@ -111,3 +135,12 @@ def test_cifar100_dataset_size():
|
|||
|
||||
ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
|
||||
assert ds_shard_3_0.get_dataset_size() == 3334
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_imagenet_rawdata_dataset_size()
|
||||
test_imagenet_tf_file_dataset_size()
|
||||
test_mnist_dataset_size()
|
||||
test_manifest_dataset_size()
|
||||
test_cifar10_dataset_size()
|
||||
test_cifar100_dataset_size()
|
||||
|
|
|
@ -229,6 +229,41 @@ def test_mnist_visualize(plot=False):
|
|||
visualize_dataset(image_list, label_list)
|
||||
|
||||
|
||||
def test_mnist_usage():
|
||||
"""
|
||||
Validate MnistDataset image readings
|
||||
"""
|
||||
logger.info("Test MnistDataset usage flag")
|
||||
|
||||
def test_config(usage, mnist_path=None):
|
||||
mnist_path = DATA_DIR if mnist_path is None else mnist_path
|
||||
try:
|
||||
data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config("test") == 10000
|
||||
assert test_config("all") == 10000
|
||||
assert " no valid data matching the dataset API MnistDataset" in test_config("train")
|
||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
||||
assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
|
||||
|
||||
# change this directory to the folder that contains all mnist files
|
||||
all_files_path = None
|
||||
# the following tests on the entire datasets
|
||||
if all_files_path is not None:
|
||||
assert test_config("train", all_files_path) == 60000
|
||||
assert test_config("test", all_files_path) == 10000
|
||||
assert test_config("all", all_files_path) == 70000
|
||||
assert ds.MnistDataset(all_files_path, usage="train").get_dataset_size() == 60000
|
||||
assert ds.MnistDataset(all_files_path, usage="test").get_dataset_size() == 10000
|
||||
assert ds.MnistDataset(all_files_path, usage="all").get_dataset_size() == 70000
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mnist_content_check()
|
||||
test_mnist_basic()
|
||||
|
@ -236,3 +271,4 @@ if __name__ == '__main__':
|
|||
test_mnist_sequential_sampler()
|
||||
test_mnist_exception()
|
||||
test_mnist_visualize(plot=True)
|
||||
test_mnist_usage()
|
||||
|
|
|
@ -21,7 +21,7 @@ TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]
|
|||
|
||||
|
||||
def test_voc_segmentation():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
|
||||
num = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1):
|
||||
assert item["image"].shape[0] == IMAGE_SHAPE[num]
|
||||
|
@ -31,7 +31,7 @@ def test_voc_segmentation():
|
|||
|
||||
|
||||
def test_voc_detection():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
num = 0
|
||||
count = [0, 0, 0, 0, 0, 0]
|
||||
for item in data1.create_dict_iterator(num_epochs=1):
|
||||
|
@ -45,7 +45,7 @@ def test_voc_detection():
|
|||
|
||||
def test_voc_class_index():
|
||||
class_index = {'car': 0, 'cat': 1, 'train': 5}
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
|
||||
class_index1 = data1.get_class_indexing()
|
||||
assert (class_index1 == {'car': 0, 'cat': 1, 'train': 5})
|
||||
data1 = data1.shuffle(4)
|
||||
|
@ -63,7 +63,7 @@ def test_voc_class_index():
|
|||
|
||||
|
||||
def test_voc_get_class_indexing():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
|
||||
class_index1 = data1.get_class_indexing()
|
||||
assert (class_index1 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5})
|
||||
data1 = data1.shuffle(4)
|
||||
|
@ -81,7 +81,7 @@ def test_voc_get_class_indexing():
|
|||
|
||||
|
||||
def test_case_0():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
|
||||
|
||||
resize_op = vision.Resize((224, 224))
|
||||
|
||||
|
@ -99,7 +99,7 @@ def test_case_0():
|
|||
|
||||
|
||||
def test_case_1():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
|
||||
|
||||
resize_op = vision.Resize((224, 224))
|
||||
|
||||
|
@ -116,7 +116,7 @@ def test_case_1():
|
|||
|
||||
|
||||
def test_case_2():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
|
||||
sizes = [0.5, 0.5]
|
||||
randomize = False
|
||||
dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
|
||||
|
@ -134,7 +134,7 @@ def test_case_2():
|
|||
|
||||
def test_voc_exception():
|
||||
try:
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", usage="train", decode=True)
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
@ -142,7 +142,7 @@ def test_voc_exception():
|
|||
pass
|
||||
|
||||
try:
|
||||
data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", class_indexing={"cat": 0}, decode=True)
|
||||
data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", class_indexing={"cat": 0}, decode=True)
|
||||
for _ in data2.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
@ -150,7 +150,7 @@ def test_voc_exception():
|
|||
pass
|
||||
|
||||
try:
|
||||
data3 = ds.VOCDataset(DATA_DIR, task="Detection", mode="notexist", decode=True)
|
||||
data3 = ds.VOCDataset(DATA_DIR, task="Detection", usage="notexist", decode=True)
|
||||
for _ in data3.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
@ -158,7 +158,7 @@ def test_voc_exception():
|
|||
pass
|
||||
|
||||
try:
|
||||
data4 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnotexist", decode=True)
|
||||
data4 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnotexist", decode=True)
|
||||
for _ in data4.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
@ -166,7 +166,7 @@ def test_voc_exception():
|
|||
pass
|
||||
|
||||
try:
|
||||
data5 = ds.VOCDataset(DATA_DIR, task="Detection", mode="invalidxml", decode=True)
|
||||
data5 = ds.VOCDataset(DATA_DIR, task="Detection", usage="invalidxml", decode=True)
|
||||
for _ in data5.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
@ -174,7 +174,7 @@ def test_voc_exception():
|
|||
pass
|
||||
|
||||
try:
|
||||
data6 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnoobject", decode=True)
|
||||
data6 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnoobject", decode=True)
|
||||
for _ in data6.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
assert False
|
||||
|
|
|
@ -35,6 +35,7 @@ def diff_mse(in1, in2):
|
|||
mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
|
||||
return mse * 100
|
||||
|
||||
|
||||
def test_cifar10():
|
||||
"""
|
||||
dataset parameter
|
||||
|
@ -45,7 +46,7 @@ def test_cifar10():
|
|||
batch_size = 32
|
||||
limit_dataset = 100
|
||||
# apply dataset operations
|
||||
data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
|
||||
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
|
||||
data1 = data1.repeat(num_repeat)
|
||||
data1 = data1.batch(batch_size, True)
|
||||
num_epoch = 5
|
||||
|
@ -139,6 +140,7 @@ def test_generator_dict_0():
|
|||
np.testing.assert_array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
|
||||
|
||||
def test_generator_dict_1():
|
||||
"""
|
||||
test generator dict 1
|
||||
|
@ -158,6 +160,7 @@ def test_generator_dict_1():
|
|||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
|
||||
def test_generator_dict_2():
|
||||
"""
|
||||
test generator dict 2
|
||||
|
@ -180,6 +183,7 @@ def test_generator_dict_2():
|
|||
assert item1
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_dict_3():
|
||||
"""
|
||||
test generator dict 3
|
||||
|
@ -226,6 +230,7 @@ def test_generator_dict_4():
|
|||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
def test_generator_dict_4_1():
|
||||
"""
|
||||
test generator dict 4_1
|
||||
|
@ -249,6 +254,7 @@ def test_generator_dict_4_1():
|
|||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
def test_generator_dict_4_2():
|
||||
"""
|
||||
test generator dict 4_2
|
||||
|
@ -274,6 +280,7 @@ def test_generator_dict_4_2():
|
|||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
def test_generator_dict_5():
|
||||
"""
|
||||
test generator dict 5
|
||||
|
@ -305,6 +312,7 @@ def test_generator_dict_5():
|
|||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
# Test tuple iterator
|
||||
|
||||
def test_generator_tuple_0():
|
||||
|
@ -323,6 +331,7 @@ def test_generator_tuple_0():
|
|||
np.testing.assert_array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
|
||||
|
||||
def test_generator_tuple_1():
|
||||
"""
|
||||
test generator tuple 1
|
||||
|
@ -342,6 +351,7 @@ def test_generator_tuple_1():
|
|||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
|
||||
def test_generator_tuple_2():
|
||||
"""
|
||||
test generator tuple 2
|
||||
|
@ -364,6 +374,7 @@ def test_generator_tuple_2():
|
|||
assert item1
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_tuple_3():
|
||||
"""
|
||||
test generator tuple 3
|
||||
|
@ -442,6 +453,7 @@ def test_generator_tuple_5():
|
|||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
# Test with repeat
|
||||
def test_generator_tuple_repeat_1():
|
||||
"""
|
||||
|
@ -536,6 +548,7 @@ def test_generator_tuple_repeat_repeat_2():
|
|||
iter1.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
|
||||
def test_generator_tuple_repeat_repeat_3():
|
||||
"""
|
||||
test generator tuple repeat repeat 3
|
||||
|
|
|
@ -149,7 +149,7 @@ def test_get_column_name_to_device():
|
|||
|
||||
|
||||
def test_get_column_name_voc():
|
||||
data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
|
||||
data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
|
||||
assert data.get_col_names() == ["image", "target"]
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ DATA_DIR = "../data/dataset/testVOC2012"
|
|||
|
||||
def test_noop_pserver():
|
||||
os.environ['MS_ROLE'] = 'MS_PSERVER'
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
|
||||
num = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num += 1
|
||||
|
@ -32,7 +32,7 @@ def test_noop_pserver():
|
|||
|
||||
def test_noop_sched():
|
||||
os.environ['MS_ROLE'] = 'MS_SCHED'
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
|
||||
num = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num += 1
|
||||
|
|
|
@ -42,8 +42,8 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False):
|
|||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
|
||||
|
||||
|
@ -108,8 +108,8 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False):
|
|||
logger.info("test_random_resized_crop_with_bbox_op_edge_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
|
||||
|
||||
|
@ -142,7 +142,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
|
|||
logger.info("test_random_resized_crop_with_bbox_op_invalid_c")
|
||||
|
||||
# Load dataset, only Augmented Dataset as test will raise ValueError
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
# If input range of scale is not in the order of (min, max), ValueError will be raised.
|
||||
|
@ -168,7 +168,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
|
|||
"""
|
||||
logger.info("test_random_resized_crop_with_bbox_op_invalid2_c")
|
||||
# Load dataset # only loading the to AugDataset as test will fail on this
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
# If input range of ratio is not in the order of (min, max), ValueError will be raised.
|
||||
|
@ -195,13 +195,13 @@ def test_random_resized_crop_with_bbox_op_bad_c():
|
|||
logger.info("test_random_resized_crop_with_bbox_op_bad_c")
|
||||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
|
||||
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -39,8 +39,8 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
|
|||
logger.info("test_random_crop_with_bbox_op_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# define test OP with values to match existing Op UT
|
||||
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
|
||||
|
@ -101,8 +101,8 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
|
|||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# define test OP with values to match existing Op unit - test
|
||||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
|
||||
|
@ -138,8 +138,8 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
|
|||
logger.info("test_random_crop_with_bbox_op3_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# define test OP with values to match existing Op unit - test
|
||||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
|
||||
|
@ -168,8 +168,8 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
|
|||
logger.info("test_random_crop_with_bbox_op_edge_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
# define test OP with values to match existing Op unit - test
|
||||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
|
||||
|
@ -205,7 +205,7 @@ def test_random_crop_with_bbox_op_invalid_c():
|
|||
logger.info("test_random_crop_with_bbox_op_invalid_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
# define test OP with values to match existing Op unit - test
|
||||
|
@ -231,13 +231,13 @@ def test_random_crop_with_bbox_op_bad_c():
|
|||
logger.info("test_random_crop_with_bbox_op_bad_c")
|
||||
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
|
||||
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
@ -247,7 +247,7 @@ def test_random_crop_with_bbox_op_bad_padding():
|
|||
"""
|
||||
logger.info("test_random_crop_with_bbox_op_invalid_c")
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1)
|
||||
|
|
|
@ -37,11 +37,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False):
|
|||
logger.info("test_random_horizontal_flip_with_bbox_op_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
|
||||
|
||||
|
@ -102,11 +100,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False):
|
|||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(0.6)
|
||||
|
||||
|
@ -140,8 +136,8 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False):
|
|||
"""
|
||||
logger.info("test_horizontal_flip_with_bbox_valid_edge_c")
|
||||
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
|
||||
|
||||
|
@ -178,7 +174,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
|
|||
"""
|
||||
logger.info("test_random_horizontal_bbox_invalid_prob_c")
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
# Note: Valid range of prob should be [0.0, 1.0]
|
||||
|
@ -201,13 +197,13 @@ def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
|
|||
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -39,11 +39,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False):
|
|||
original_seed = config_get_set_seed(123)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomResizeWithBBox(100)
|
||||
|
||||
|
@ -120,11 +118,9 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False):
|
|||
box has dimensions as the image itself.
|
||||
"""
|
||||
logger.info("test_random_resize_with_bbox_op_edge_c")
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomResizeWithBBox(500)
|
||||
|
||||
|
@ -197,13 +193,13 @@ def test_random_resize_with_bbox_op_bad_c():
|
|||
logger.info("test_random_resize_with_bbox_op_bad_c")
|
||||
test_op = c_vision.RandomResizeWithBBox((400, 300))
|
||||
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -37,11 +37,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
|
|||
"""
|
||||
logger.info("test_random_vertical_flip_with_bbox_op_c")
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomVerticalFlipWithBBox(1)
|
||||
|
||||
|
@ -102,11 +100,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
|
|||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomVerticalFlipWithBBox(0.8)
|
||||
|
||||
|
@ -139,11 +135,9 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False):
|
|||
applied on dynamically generated edge case, expected to pass
|
||||
"""
|
||||
logger.info("test_random_vertical_flip_with_bbox_op_edge_c")
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.RandomVerticalFlipWithBBox(1)
|
||||
|
||||
|
@ -174,8 +168,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
|
|||
Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError
|
||||
"""
|
||||
logger.info("test_random_vertical_flip_with_bbox_op_invalid_c")
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
try:
|
||||
test_op = c_vision.RandomVerticalFlipWithBBox(2)
|
||||
|
@ -201,13 +194,13 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
|
|||
logger.info("test_random_vertical_flip_with_bbox_op_bad_c")
|
||||
test_op = c_vision.RandomVerticalFlipWithBBox(1)
|
||||
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -39,11 +39,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False):
|
|||
logger.info("test_resize_with_bbox_op_voc_c")
|
||||
|
||||
# Load dataset
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.ResizeWithBBox(100)
|
||||
|
||||
|
@ -110,11 +108,9 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False):
|
|||
box has dimensions as the image itself.
|
||||
"""
|
||||
logger.info("test_resize_with_bbox_op_edge_c")
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
|
||||
decode=True, shuffle=False)
|
||||
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
||||
test_op = c_vision.ResizeWithBBox(500)
|
||||
|
||||
|
@ -163,13 +159,13 @@ def test_resize_with_bbox_op_bad_c():
|
|||
logger.info("test_resize_with_bbox_op_bad_c")
|
||||
test_op = c_vision.ResizeWithBBox((200, 300))
|
||||
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
|
||||
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
|
||||
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from mindspore.dataset.vision import Inter
|
|||
|
||||
|
||||
|
||||
|
||||
def test_imagefolder(remove_json_files=True):
|
||||
"""
|
||||
Test simulating resnet50 dataset pipeline.
|
||||
|
@ -103,7 +104,7 @@ def test_mnist_dataset(remove_json_files=True):
|
|||
data_dir = "../data/dataset/testMnistData"
|
||||
ds.config.set_seed(1)
|
||||
|
||||
data1 = ds.MnistDataset(data_dir, 100)
|
||||
data1 = ds.MnistDataset(data_dir, num_samples=100)
|
||||
one_hot_encode = c.OneHot(10) # num_classes is input argument
|
||||
data1 = data1.map(input_columns="label", operations=one_hot_encode)
|
||||
|
||||
|
|
Loading…
Reference in New Issue