forked from mindspore-Ecosystem/mindspore
!20047 [assistant][ops]New operator implementation, include QmnistDataset
Merge pull request !20047 from Wangsong95/qmnist_dataset
This commit is contained in:
commit
2969688382
|
@ -102,6 +102,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
|
||||
|
@ -1220,6 +1221,28 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
|
|||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
|
||||
|
@ -248,6 +249,17 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode",
|
||||
"to create an QMnistNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, bool compat, py::handle sampler) {
|
||||
auto qmnist =
|
||||
std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(qmnist->ValidateParams());
|
||||
return qmnist;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
|
||||
"to create a RandomNode")
|
||||
|
|
|
@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
cityscapes_op.cc
|
||||
div2k_op.cc
|
||||
flickr_op.cc
|
||||
qmnist_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -81,7 +81,7 @@ class MnistOp : public MappableLeafOp {
|
|||
// \return DatasetName of the current Op
|
||||
virtual std::string DatasetName(bool upper = false) const { return upper ? "Mnist" : "mnist"; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Load a tensor row according to a pair
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param ImageLabelPair pair - <imagefile,label>
|
||||
|
@ -94,14 +94,14 @@ class MnistOp : public MappableLeafOp {
|
|||
// @param std::ifstream *image_reader - image file stream
|
||||
// @param uint32_t num_images - returns the number of images
|
||||
// @return Status The status code returned
|
||||
Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
|
||||
virtual Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
|
||||
|
||||
// Check label stream.
|
||||
// @param const std::string &file_name - label file name
|
||||
// @param std::ifstream *label_reader - label file stream
|
||||
// @param uint32_t num_labels - returns the number of labels
|
||||
// @return Status The status code returned
|
||||
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
|
||||
virtual Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
|
||||
|
||||
// Read 4 bytes of data from a file stream.
|
||||
// @param std::ifstream *reader - file stream to read
|
||||
|
@ -118,7 +118,7 @@ class MnistOp : public MappableLeafOp {
|
|||
// @param std::ifstream *label_reader - label file stream
|
||||
// @param int64_t read_num - number of image to read
|
||||
// @return Status The status code returned
|
||||
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
|
||||
virtual Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
|
||||
|
||||
// Parse all mnist dataset files
|
||||
// @return Status The status code returned
|
||||
|
@ -126,7 +126,7 @@ class MnistOp : public MappableLeafOp {
|
|||
|
||||
// Read all files in the directory
|
||||
// @return Status The status code returned
|
||||
Status WalkAllFiles();
|
||||
virtual Status WalkAllFiles();
|
||||
|
||||
// Called first when function is called
|
||||
// @return Status The status code returned
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/qmnist_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
#include "debug/common.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const int32_t kQMnistLabelFileMagicNumber = 3074;
|
||||
const int32_t kQMnistImageRows = 28;
|
||||
const int32_t kQMnistImageCols = 28;
|
||||
const int32_t kQMnistLabelLength = 8;
|
||||
|
||||
QMnistOp::QMnistOp(const std::string &folder_path, const std::string &usage, bool compat,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t num_workers,
|
||||
int32_t queue_size)
|
||||
: MnistOp(usage, num_workers, folder_path, queue_size, std::move(data_schema), std::move(sampler)),
|
||||
compat_(compat) {}
|
||||
|
||||
void QMnistOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows: " << num_rows_ << "\n"
|
||||
<< DatasetName(true) << " directory: " << folder_path_ << "\nUsage: " << usage_
|
||||
<< "\nCompat: " << (compat_ ? "yes" : "no") << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Load 1 TensorRow (image, label) using 1 MnistLabelPair or QMnistImageInfoPair.
|
||||
Status QMnistOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
if (compat_) {
|
||||
MnistLabelPair qmnist_pair = image_label_pairs_[row_id];
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(qmnist_pair.second, &label));
|
||||
} else {
|
||||
QMnistImageInfoPair qmnist_pair = image_info_pairs_[row_id];
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.second, &label));
|
||||
}
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
trow->setPath({image_path_[row_id], label_path_[row_id]});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connector_size = cfg->op_connector_size();
|
||||
|
||||
// compat does not affect the count result, so set it to true default.
|
||||
auto op =
|
||||
std::make_shared<QMnistOp>(dir, usage, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);
|
||||
|
||||
// the logic of counting the number of samples
|
||||
RETURN_IF_NOT_OK(op->WalkAllFiles());
|
||||
for (size_t i = 0; i < op->image_names_.size(); ++i) {
|
||||
std::ifstream image_reader;
|
||||
image_reader.open(op->image_names_[i], std::ios::binary);
|
||||
std::ifstream label_reader;
|
||||
label_reader.open(op->label_names_[i], std::ios::binary);
|
||||
|
||||
uint32_t num_images;
|
||||
RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images));
|
||||
uint32_t num_labels;
|
||||
RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
|
||||
"Invalid data, num of images is not equal to num of labels.");
|
||||
|
||||
if (usage == "test10k") {
|
||||
// only use the first 10k samples and drop the last 50k samples
|
||||
num_images = 10000;
|
||||
num_labels = 10000;
|
||||
} else if (usage == "test50k") {
|
||||
// only use the last 50k samples and drop the first 10k samples
|
||||
num_images = 50000;
|
||||
num_labels = 50000;
|
||||
}
|
||||
|
||||
*count = *count + num_images;
|
||||
|
||||
// Close the readers
|
||||
image_reader.close();
|
||||
label_reader.close();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistOp::WalkAllFiles() {
|
||||
const std::string image_ext = "images-idx3-ubyte";
|
||||
const std::string label_ext = "labels-idx2-int";
|
||||
const std::string train_prefix = "qmnist-train";
|
||||
const std::string test_prefix = "qmnist-test";
|
||||
const std::string nist_prefix = "xnist";
|
||||
|
||||
auto real_folder_path = FileUtils::GetRealPath(folder_path_.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
|
||||
Path root_dir(real_folder_path.value());
|
||||
|
||||
if (usage_ == "train") {
|
||||
image_names_.push_back((root_dir / Path(train_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(train_prefix + "-" + label_ext)).ToString());
|
||||
} else if (usage_ == "test" || usage_ == "test10k" || usage_ == "test50k") {
|
||||
image_names_.push_back((root_dir / Path(test_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(test_prefix + "-" + label_ext)).ToString());
|
||||
} else if (usage_ == "nist") {
|
||||
image_names_.push_back((root_dir / Path(nist_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(nist_prefix + "-" + label_ext)).ToString());
|
||||
} else if (usage_ == "all") {
|
||||
image_names_.push_back((root_dir / Path(train_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(train_prefix + "-" + label_ext)).ToString());
|
||||
image_names_.push_back((root_dir / Path(test_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(test_prefix + "-" + label_ext)).ToString());
|
||||
image_names_.push_back((root_dir / Path(nist_prefix + "-" + image_ext)).ToString());
|
||||
label_names_.push_back((root_dir / Path(nist_prefix + "-" + label_ext)).ToString());
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(),
|
||||
"Invalid data, num of images does not equal to num of labels.");
|
||||
|
||||
for (size_t i = 0; i < image_names_.size(); i++) {
|
||||
Path file_path(image_names_[i]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(),
|
||||
"Failed to find " + DatasetName() + " image file: " + file_path.ToString());
|
||||
MS_LOG(INFO) << DatasetName(true) << " operator found image file at " << file_path.ToString() << ".";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < label_names_.size(); i++) {
|
||||
Path file_path(label_names_[i]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(),
|
||||
"Failed to find " + DatasetName() + " label file: " + file_path.ToString());
|
||||
MS_LOG(INFO) << DatasetName(true) << " operator found label file at " << file_path.ToString() << ".";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
|
||||
RETURN_UNEXPECTED_IF_NULL(image_reader);
|
||||
RETURN_UNEXPECTED_IF_NULL(label_reader);
|
||||
uint32_t num_images, num_labels;
|
||||
RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images));
|
||||
RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
|
||||
"Invalid data, num_images is not equal to num_labels. Ensure data file is not damaged.");
|
||||
|
||||
// The image size of the QMNIST dataset is fixed at [28,28]
|
||||
int64_t image_size = kQMnistImageRows * kQMnistImageCols;
|
||||
int64_t label_length = kQMnistLabelLength;
|
||||
|
||||
if (usage_ == "test10k") {
|
||||
// only use the first 10k samples and drop the last 50k samples
|
||||
num_images = 10000;
|
||||
num_labels = 10000;
|
||||
} else if (usage_ == "test50k") {
|
||||
num_images = 50000;
|
||||
num_labels = 50000;
|
||||
// skip the first 10k samples for ifstream reader
|
||||
(void)image_reader->ignore(image_size * 10000);
|
||||
(void)label_reader->ignore(label_length * 10000 * 4);
|
||||
}
|
||||
|
||||
auto images_buf = std::make_unique<char[]>(image_size * num_images);
|
||||
auto labels_buf = std::make_unique<uint32_t[]>(label_length * num_labels);
|
||||
if (images_buf == nullptr || labels_buf == nullptr) {
|
||||
std::string err_msg = "[Internal ERROR] Failed to allocate memory for " + DatasetName() + " buffer.";
|
||||
MS_LOG(ERROR) << err_msg.c_str();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
(void)image_reader->read(images_buf.get(), image_size * num_images);
|
||||
if (image_reader->fail()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " image: " + image_names_[index] +
|
||||
", size:" + std::to_string(image_size * num_images) +
|
||||
". Ensure data file is not damaged.");
|
||||
}
|
||||
// uint32_t use 4 bytes in memory
|
||||
(void)label_reader->read(reinterpret_cast<char *>(labels_buf.get()), label_length * num_labels * 4);
|
||||
if (label_reader->fail()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " label:" + label_names_[index] +
|
||||
", size: " + std::to_string(label_length * num_labels) +
|
||||
". Ensure data file is not damaged.");
|
||||
}
|
||||
TensorShape image_tensor_shape = TensorShape({kQMnistImageRows, kQMnistImageCols, 1});
|
||||
TensorShape label_tensor_shape = TensorShape({kQMnistLabelLength});
|
||||
for (int64_t data_index = 0; data_index != num_images; data_index++) {
|
||||
auto image = &images_buf[data_index * image_size];
|
||||
for (int64_t image_index = 0; image_index < image_size; image_index++) {
|
||||
image[image_index] = (image[image_index] == 0) ? 0 : 255;
|
||||
}
|
||||
std::shared_ptr<Tensor> image_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(),
|
||||
reinterpret_cast<unsigned char *>(image), &image_tensor));
|
||||
|
||||
auto label = &labels_buf[data_index * label_length];
|
||||
for (int64_t label_index = 0; label_index < label_length; label_index++) {
|
||||
label[label_index] = SwapEndian(label[label_index]);
|
||||
}
|
||||
std::shared_ptr<Tensor> label_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(label_tensor_shape, data_schema_->Column(1).Type(),
|
||||
reinterpret_cast<unsigned char *>(label), &label_tensor));
|
||||
|
||||
image_info_pairs_.emplace_back(std::make_pair(image_tensor, label_tensor));
|
||||
image_label_pairs_.emplace_back(std::make_pair(image_tensor, label[0]));
|
||||
image_path_.push_back(image_names_[index]);
|
||||
label_path_.push_back(label_names_[index]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
|
||||
RETURN_UNEXPECTED_IF_NULL(label_reader);
|
||||
RETURN_UNEXPECTED_IF_NULL(num_labels);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(),
|
||||
"Invalid file, failed to open " + DatasetName() + " label file: " + file_name);
|
||||
int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
|
||||
(void)label_reader->seekg(0, std::ios::beg);
|
||||
// The first 12 bytes of the label file are type, number and length
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 12, "Invalid file, " + DatasetName() + " file is corrupted: " + file_name);
|
||||
uint32_t magic_number;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kQMnistLabelFileMagicNumber,
|
||||
"Invalid file, this is not the " + DatasetName() + " label file: " + file_name);
|
||||
uint32_t num_items;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
|
||||
uint32_t length;
|
||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &length));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(length == kQMnistLabelLength, "Invalid data, length of labels is not equal to 8.");
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((label_len - 12) == num_items * kQMnistLabelLength * 4,
|
||||
"Invalid data, number of labels is wrong.");
|
||||
*num_labels = num_items;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
using QMnistImageInfoPair = std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>;
|
||||
|
||||
class QMnistOp : public MnistOp {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param const std::string &folder_path - dir directory of QMNIST data file.
|
||||
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or
|
||||
// 'all'.
|
||||
// @param bool compat - Compatibility with Mnist.
|
||||
// @param std::unique_ptr<DataSchema> data_schema - the schema of the QMNIST dataset.
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells QMnistOp what to read.
|
||||
// @param int32_t num_workers - number of workers reading images in parallel.
|
||||
// @param int32_t queue_size - connector queue size.
|
||||
QMnistOp(const std::string &folder_path, const std::string &usage, bool compat,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t num_workers,
|
||||
int32_t queue_size);
|
||||
|
||||
// Destructor.
|
||||
~QMnistOp() = default;
|
||||
|
||||
// Op name getter.
|
||||
// @return std::string - Name of the current Op.
|
||||
std::string Name() const override { return "QMnistOp"; }
|
||||
|
||||
// DatasetName name getter
|
||||
// \return std::string - DatasetName of the current Op
|
||||
std::string DatasetName(bool upper = false) const { return upper ? "QMnist" : "qmnist"; }
|
||||
|
||||
// A print method typically used for debugging.
|
||||
// @param std::ostream &out - out stream.
|
||||
// @param bool show_all - whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Function to count the number of samples in the QMNIST dataset.
|
||||
// @param const std::string &dir - path to the QMNIST directory.
|
||||
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or
|
||||
// 'all'.
|
||||
// @param int64_t *count - output arg that will hold the actual dataset size.
|
||||
// @return Status -The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
|
||||
|
||||
private:
|
||||
// Load a tensor row according to a pair.
|
||||
// @param row_id_type row_id - id for this tensor row.
|
||||
// @param TensorRow row - image & label read into this tensor row.
|
||||
// @return Status - The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
// Get needed files in the folder_path_.
|
||||
// @return Status - The status code returned.
|
||||
Status WalkAllFiles() override;
|
||||
|
||||
// Read images and labels from the file stream.
|
||||
// @param std::ifstream *image_reader - image file stream.
|
||||
// @param std::ifstream *label_reader - label file stream.
|
||||
// @param size_t index - the index of file that is reading.
|
||||
// @return Status The status code returned.
|
||||
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) override;
|
||||
|
||||
// Check label stream.
|
||||
// @param const std::string &file_name - label file name.
|
||||
// @param std::ifstream *label_reader - label file stream.
|
||||
// @param uint32_t num_labels - returns the number of labels.
|
||||
// @return Status The status code returned.
|
||||
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) override;
|
||||
|
||||
const bool compat_; // compatible with mnist
|
||||
|
||||
std::vector<QMnistImageInfoPair> image_info_pairs_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_
|
|
@ -89,6 +89,7 @@ constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
|||
constexpr char kManifestNode[] = "ManifestDataset";
|
||||
constexpr char kMindDataNode[] = "MindDataDataset";
|
||||
constexpr char kMnistNode[] = "MnistDataset";
|
||||
constexpr char kQMnistNode[] = "QMnistDataset";
|
||||
constexpr char kRandomNode[] = "RandomDataset";
|
||||
constexpr char kSBUNode[] = "SBUDataset";
|
||||
constexpr char kTextFileNode[] = "TextFileDataset";
|
||||
|
|
|
@ -17,6 +17,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
manifest_node.cc
|
||||
minddata_node.cc
|
||||
mnist_node.cc
|
||||
qmnist_node.cc
|
||||
random_node.cc
|
||||
sbu_node.cc
|
||||
text_file_node.cc
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/qmnist_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
compat_(compat),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> QMnistNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<QMnistNode>(dataset_dir_, usage_, compat_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void QMnistNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(dataset dir: " + dataset_dir_ + ", usage: " + usage_ +
|
||||
", compat: " + (compat_ ? "true" : "false") + ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
|
||||
}
|
||||
|
||||
Status QMnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("QMnistNode", usage_, {"train", "test", "test10k", "test50k", "nist", "all"}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
if (compat_) {
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<QMnistOp>(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt),
|
||||
num_workers_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status QMnistNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status QMnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(QMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QMnistNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
args["compat"] = compat_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("compat") != json_obj.end(), "Failed to find compat");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
bool compat = json_obj["compat"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
|
||||
std::shared_ptr<DatasetCache> cache = nullptr;
|
||||
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
|
||||
*ds = std::make_shared<QMnistNode>(dataset_dir, usage, compat, sampler, cache);
|
||||
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class QMnistNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~QMnistNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kQMnistNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param out - The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id The shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Getter functions.
|
||||
const bool Compat() const { return compat_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to read dataset in json
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
bool compat_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_
|
|
@ -2273,6 +2273,90 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
|
|||
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class QMnistDataset
|
||||
/// \brief A source dataset that reads and parses QMNIST dataset.
|
||||
class QMnistDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of QMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
|
||||
/// \param[in] compat Whether the label for each example is class number (compat=true)
|
||||
/// or the full QMNIST information (compat=false).
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of QMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
|
||||
/// \param[in] compat Whether the label for each example is class number (compat=true)
|
||||
/// or the full QMNIST information (compat=false).
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of QMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
|
||||
/// \param[in] compat Whether the label for each example is class number (compat=true)
|
||||
/// or the full QMNIST information (compat=false).
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// Destructor of QMnistDataset.
|
||||
~QMnistDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a QMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all" (default = "all").
|
||||
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information
|
||||
/// (default = true).
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the QMnistDataset.
|
||||
inline std::shared_ptr<QMnistDataset> QMnist(
|
||||
const std::string &dataset_dir, const std::string &usage = "all", bool compat = true,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<QMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a QMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
|
||||
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the QMnistDataset.
|
||||
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<QMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a QMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
|
||||
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the QMnistDataset.
|
||||
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<QMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a ConcatDataset.
|
||||
/// \note Reload "+" operator to concat two datasets.
|
||||
/// \param[in] datasets1 Shared pointer to the first dataset to be concatenated.
|
||||
|
@ -2565,15 +2649,14 @@ class USPSDataset : public Dataset {
|
|||
public:
|
||||
/// \brief Constructor of USPSDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset
|
||||
/// (Default = 0 means all samples).
|
||||
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
|
||||
/// \param[in] usage Usage of USPS, can be "train", "test" or "all".
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||
/// \param[in] shuffle The mode for shuffling data every epoch.
|
||||
/// Can be any of:
|
||||
/// ShuffleMode.kFalse - No shuffling is performed.
|
||||
/// ShuffleMode.kFiles - Shuffle files only.
|
||||
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into.
|
||||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified (Default = 0).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
|
|
|
@ -44,6 +44,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class ManifestDataset;
|
||||
friend class MindDataDataset;
|
||||
friend class MnistDataset;
|
||||
friend class QMnistDataset;
|
||||
friend class RandomDataDataset;
|
||||
friend class SBUDataset;
|
||||
friend class TextFileDataset;
|
||||
|
|
|
@ -66,7 +66,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset
|
||||
check_sbu_dataset, check_qmnist_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -3441,6 +3441,131 @@ class MnistDataset(MappableDataset):
|
|||
return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class QMnistDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the QMNIST dataset.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[image, label]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`label` is a scalar when `compat` is True else a tensor both of the uint32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be `train`, `test`, `test10k`, `test50k`, `nist`
|
||||
or `all` (default=None, will read all samples).
|
||||
compat (bool, optional): Whether the label for each example is class number (compat=True) or the full QMNIST
|
||||
information (compat=False) (default=True).
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||
argument can only be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> qmnist_dataset_dir = "/path/to/qmnist_dataset_directory"
|
||||
>>>
|
||||
>>> # Read 3 samples from QMNIST train dataset
|
||||
>>> dataset = ds.QMnistDataset(dataset_dir=qmnist_dataset_dir, num_samples=3)
|
||||
>>>
|
||||
>>> # Note: In QMNIST dataset, each dictionary has keys "image" and "label"
|
||||
|
||||
About QMNIST dataset:
|
||||
|
||||
The QMNIST dataset was generated from the original data found in the NIST Special Database 19 with the goal to
|
||||
match the MNIST preprocessing as closely as possible.
|
||||
Through an iterative process, researchers tried to generate an additional 50k images of MNIST-like data.
|
||||
They started with a reconstruction process given in the paper and used the Hungarian algorithm to find the best
|
||||
matches between the original MNIST samples and their reconstructed samples.
|
||||
|
||||
Here is the original QMNIST dataset structure.
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── qmnist_dataset_dir
|
||||
├── qmnist-train-images-idx3-ubyte
|
||||
├── qmnist-train-labels-idx2-int
|
||||
├── qmnist-test-images-idx3-ubyte
|
||||
├── qmnist-test-labels-idx2-int
|
||||
├── xnist-images-idx3-ubyte
|
||||
└── xnist-labels-idx2-int
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@incollection{qmnist-2019,
|
||||
title = "Cold Case: The Lost MNIST Digits",
|
||||
author = "Chhavi Yadav and L\'{e}on Bottou",\
|
||||
booktitle = {Advances in Neural Information Processing Systems 32},
|
||||
year = {2019},
|
||||
publisher = {Curran Associates, Inc.},
|
||||
}
|
||||
"""
|
||||
|
||||
@check_qmnist_dataset
|
||||
def __init__(self, dataset_dir, usage=None, compat=True, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
self.compat = compat
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.QMnistNode(self.dataset_dir, self.usage, self.compat, self.sampler)
|
||||
|
||||
|
||||
class MindDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing MindRecord dataset.
|
||||
|
|
|
@ -92,6 +92,36 @@ def check_mnist_cifar_dataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_qmnist_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(QMnistDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle', 'compat']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage")
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_manifestdataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_minddata_test.cc
|
||||
c_api_dataset_ops_test.cc
|
||||
c_api_dataset_qmnist_test.cc
|
||||
c_api_dataset_randomdata_test.cc
|
||||
c_api_dataset_save.cc
|
||||
c_api_dataset_sbu_test.cc
|
||||
|
|
|
@ -0,0 +1,343 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistTrainDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDataset.";
|
||||
|
||||
// Create a QMNIST Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistTestDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTestDataset.";
|
||||
|
||||
// Create a QMNIST Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "test", true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistNistDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistNistDataset.";
|
||||
|
||||
// Create a QMNIST Nist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "nist", true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistAllDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistAllDataset.";
|
||||
|
||||
// Create a QMNIST All Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "all", true, std::make_shared<RandomSampler>(false, 20));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistCompatDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistCompatDataset.";
|
||||
|
||||
// Create a QMNIST All Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "all", false, std::make_shared<RandomSampler>(false, 20));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
MS_LOG(INFO) << "Tensor label shape: " << label.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetWithPipeline.";
|
||||
|
||||
// Create two QMNIST Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds1 = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
|
||||
std::shared_ptr<Dataset> ds2 = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 1;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 1;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestGetQMnistDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetQMnistTrainDatasetSize.";
|
||||
|
||||
// Create a QMNIST Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetGetters.";
|
||||
|
||||
// Create a QMNIST Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image", "label"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "uint32");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<28,28,1>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistDataFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataFail.";
|
||||
|
||||
// Create a QMNIST Dataset
|
||||
std::shared_ptr<Dataset> ds = QMnist("", "train", true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid QMNIST input
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistDataWithInvalidUsageFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithInvalidUsageFail.";
|
||||
|
||||
// Create a QMNIST Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "validation", true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid QMNIST input, validation is not a valid usage
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestQMnistDataWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithNullSamplerFail.";
|
||||
|
||||
// Create a QMNIST Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testQMnistData/";
|
||||
std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid QMNIST input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,343 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test QMnistDataset operator
|
||||
"""
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testQMnistData"
|
||||
|
||||
|
||||
def load_qmnist(path, usage, compat=True):
|
||||
"""
|
||||
load QMNIST data
|
||||
"""
|
||||
image_path = []
|
||||
label_path = []
|
||||
image_ext = "images-idx3-ubyte"
|
||||
label_ext = "labels-idx2-int"
|
||||
train_prefix = "qmnist-train"
|
||||
test_prefix = "qmnist-test"
|
||||
nist_prefix = "xnist"
|
||||
assert usage in ["train", "test", "nist", "all"]
|
||||
if usage == "train":
|
||||
image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext)))
|
||||
elif usage == "test":
|
||||
image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext)))
|
||||
elif usage == "nist":
|
||||
image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext)))
|
||||
elif usage == "all":
|
||||
image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext)))
|
||||
image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext)))
|
||||
image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext)))
|
||||
label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext)))
|
||||
|
||||
assert len(image_path) == len(label_path)
|
||||
|
||||
images = []
|
||||
labels = []
|
||||
for i, _ in enumerate(image_path):
|
||||
with open(image_path[i], 'rb') as image_file:
|
||||
image_file.read(16)
|
||||
image = np.fromfile(image_file, dtype=np.uint8)
|
||||
image = image.reshape(-1, 28, 28, 1)
|
||||
image[image > 0] = 255 # Perform binarization to maintain consistency with our API
|
||||
images.append(image)
|
||||
with open(label_path[i], 'rb') as label_file:
|
||||
label_file.read(12)
|
||||
label = np.fromfile(label_file, dtype='>u4')
|
||||
label = label.reshape(-1, 8)
|
||||
labels.append(label)
|
||||
|
||||
images = np.concatenate(images, 0)
|
||||
labels = np.concatenate(labels, 0)
|
||||
if compat:
|
||||
return images, labels[:, 0]
|
||||
return images, labels
|
||||
|
||||
|
||||
def visualize_dataset(images, labels):
|
||||
"""
|
||||
Helper function to visualize the dataset samples
|
||||
"""
|
||||
num_samples = len(images)
|
||||
for i in range(num_samples):
|
||||
plt.subplot(1, num_samples, i + 1)
|
||||
plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
|
||||
plt.title(labels[i])
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_qmnist_content_check():
|
||||
"""
|
||||
Validate QMnistDataset image readings
|
||||
"""
|
||||
logger.info("Test QMnistDataset Op with content check")
|
||||
for usage in ["train", "test", "nist", "all"]:
|
||||
data1 = ds.QMnistDataset(DATA_DIR, usage, True, num_samples=10, shuffle=False)
|
||||
images, labels = load_qmnist(DATA_DIR, usage, True)
|
||||
num_iter = 0
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
image_list, label_list = [], []
|
||||
for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_list.append(data["image"])
|
||||
label_list.append("label {}".format(data["label"]))
|
||||
np.testing.assert_array_equal(data["image"], images[i])
|
||||
np.testing.assert_array_equal(data["label"], labels[i])
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
for usage in ["train", "test", "nist", "all"]:
|
||||
data1 = ds.QMnistDataset(DATA_DIR, usage, False, num_samples=10, shuffle=False)
|
||||
images, labels = load_qmnist(DATA_DIR, usage, False)
|
||||
num_iter = 0
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
image_list, label_list = [], []
|
||||
for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_list.append(data["image"])
|
||||
label_list.append("label {}".format(data["label"]))
|
||||
np.testing.assert_array_equal(data["image"], images[i])
|
||||
np.testing.assert_array_equal(data["label"], labels[i])
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_qmnist_basic():
|
||||
"""
|
||||
Validate QMnistDataset
|
||||
"""
|
||||
logger.info("Test QMnistDataset Op")
|
||||
|
||||
# case 1: test loading whole dataset
|
||||
data1 = ds.QMnistDataset(DATA_DIR, "train", True)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 10
|
||||
|
||||
# case 2: test num_samples
|
||||
data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 5
|
||||
|
||||
# case 3: test repeat
|
||||
data3 = ds.QMnistDataset(DATA_DIR, "train", True)
|
||||
data3 = data3.repeat(5)
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 50
|
||||
|
||||
# case 4: test batch with drop_remainder=False
|
||||
data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
|
||||
assert data4.get_dataset_size() == 10
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.batch(batch_size=7) # drop_remainder is default to be False
|
||||
assert data4.get_dataset_size() == 2
|
||||
assert data4.get_batch_size() == 7
|
||||
num_iter4 = 0
|
||||
for _ in data4.create_dict_iterator(num_epochs=1):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 2
|
||||
|
||||
# case 5: test batch with drop_remainder=True
|
||||
data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
|
||||
assert data5.get_dataset_size() == 10
|
||||
assert data5.get_batch_size() == 1
|
||||
data5 = data5.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert data5.get_dataset_size() == 3
|
||||
assert data5.get_batch_size() == 3
|
||||
num_iter5 = 0
|
||||
for _ in data5.create_dict_iterator(num_epochs=1):
|
||||
num_iter5 += 1
|
||||
assert num_iter5 == 3
|
||||
|
||||
# case 6: test get_col_names
|
||||
dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
|
||||
assert dataset.get_col_names() == ["image", "label"]
|
||||
|
||||
|
||||
def test_qmnist_pk_sampler():
|
||||
"""
|
||||
Test QMnistDataset with PKSampler
|
||||
"""
|
||||
logger.info("Test QMnistDataset Op with PKSampler")
|
||||
golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
sampler = ds.PKSampler(10)
|
||||
data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=sampler)
|
||||
num_iter = 0
|
||||
label_list = []
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
label_list.append(item["label"])
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(golden, label_list)
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_qmnist_sequential_sampler():
|
||||
"""
|
||||
Test QMnistDataset with SequentialSampler
|
||||
"""
|
||||
logger.info("Test QMnistDataset Op with SequentialSampler")
|
||||
num_samples = 10
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.QMnistDataset(DATA_DIR, "train", True, sampler=sampler)
|
||||
data2 = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=num_samples)
|
||||
label_list1, label_list2 = [], []
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
|
||||
label_list1.append(item1["label"].asnumpy())
|
||||
label_list2.append(item2["label"].asnumpy())
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(label_list1, label_list2)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_qmnist_exception():
|
||||
"""
|
||||
Test error cases for QMnistDataset
|
||||
"""
|
||||
logger.info("Test error cases for MnistDataset")
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
error_msg_8 = "The corresponding data files"
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.QMnistDataset(DATA_DIR, "train", True)
|
||||
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.QMnistDataset(DATA_DIR, "train", True)
|
||||
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
|
||||
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.QMnistDataset(DATA_DIR, "train", True)
|
||||
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
|
||||
|
||||
def test_qmnist_visualize(plot=False):
|
||||
"""
|
||||
Visualize QMnistDataset results
|
||||
"""
|
||||
logger.info("Test QMnistDataset visualization")
|
||||
|
||||
data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False)
|
||||
num_iter = 0
|
||||
image_list, label_list = [], []
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = item["image"]
|
||||
label = item["label"]
|
||||
image_list.append(image)
|
||||
label_list.append("label {}".format(label))
|
||||
assert isinstance(image, np.ndarray)
|
||||
assert image.shape == (28, 28, 1)
|
||||
assert image.dtype == np.uint8
|
||||
assert label.dtype == np.uint32
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
if plot:
|
||||
visualize_dataset(image_list, label_list)
|
||||
|
||||
|
||||
def test_qmnist_usage():
|
||||
"""
|
||||
Validate QMnistDataset image readings
|
||||
"""
|
||||
logger.info("Test QMnistDataset usage flag")
|
||||
|
||||
def test_config(usage, path=None):
|
||||
path = DATA_DIR if path is None else path
|
||||
try:
|
||||
data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config("train") == 10
|
||||
assert test_config("test") == 10
|
||||
assert test_config("nist") == 10
|
||||
assert test_config("all") == 30
|
||||
assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in\
|
||||
test_config("invalid")
|
||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_qmnist_content_check()
|
||||
test_qmnist_basic()
|
||||
test_qmnist_pk_sampler()
|
||||
test_qmnist_sequential_sampler()
|
||||
test_qmnist_exception()
|
||||
test_qmnist_visualize(plot=True)
|
||||
test_qmnist_usage()
|
Loading…
Reference in New Issue