!20047 [assistant][ops] New operator implementation, including QMnistDataset

Merge pull request !20047 from Wangsong95/qmnist_dataset
This commit is contained in:
i-robot 2021-09-27 01:36:08 +00:00 committed by Gitee
commit 2969688382
23 changed files with 1632 additions and 11 deletions

View File

@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
@ -1220,6 +1221,28 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
}
#ifndef ENABLE_ANDROID
// Construct a QMnistDataset from a shared-pointer sampler by building the corresponding IR node.
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                             const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  // A null sampler is forwarded as a null SamplerObj; the IR layer supplies the default.
  std::shared_ptr<SamplerObj> sampler_obj = nullptr;
  if (sampler != nullptr) {
    sampler_obj = sampler->Parse();
  }
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache));
}
// Construct a QMnistDataset from a raw (non-owning) sampler pointer by building the corresponding IR node.
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                             const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
  // A null raw sampler is forwarded as a null SamplerObj; the IR layer supplies the default.
  std::shared_ptr<SamplerObj> sampler_obj = nullptr;
  if (sampler != nullptr) {
    sampler_obj = sampler->Parse();
  }
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache));
}
// Construct a QMnistDataset from a sampler reference by building the corresponding IR node.
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                             const std::reference_wrapper<Sampler> sampler,
                             const std::shared_ptr<DatasetCache> &cache) {
  // A reference wrapper can never be null, so the sampler is parsed unconditionally.
  std::shared_ptr<SamplerObj> sampler_obj = sampler.get().Parse();
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache));
}
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {

View File

@ -44,6 +44,7 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
@ -248,6 +249,17 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
}));
}));
// Python binding for QMnistNode. The doc string is corrected to "a QMnistNode"
// ("an" before a consonant sound was a typo) for consistency with the sibling
// registrations (e.g. "to create a RandomNode").
PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
                  (void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode",
                                                                                        "to create a QMnistNode")
                    .def(py::init([](std::string dataset_dir, std::string usage, bool compat, py::handle sampler) {
                      // Validate eagerly so bad arguments raise on construction, not at iteration time.
                      auto qmnist =
                        std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr);
                      THROW_IF_ERROR(qmnist->ValidateParams());
                      return qmnist;
                    }));
                }));
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
"to create a RandomNode")

View File

@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
cityscapes_op.cc
div2k_op.cc
flickr_op.cc
qmnist_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

View File

@ -81,7 +81,7 @@ class MnistOp : public MappableLeafOp {
// \return DatasetName of the current Op
virtual std::string DatasetName(bool upper = false) const { return upper ? "Mnist" : "mnist"; }
private:
protected:
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
@ -94,14 +94,14 @@ class MnistOp : public MappableLeafOp {
// @param std::ifstream *image_reader - image file stream
// @param uint32_t num_images - returns the number of images
// @return Status The status code returned
Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
virtual Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
// Check label stream.
// @param const std::string &file_name - label file name
// @param std::ifstream *label_reader - label file stream
// @param uint32_t num_labels - returns the number of labels
// @return Status The status code returned
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
virtual Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
// Read 4 bytes of data from a file stream.
// @param std::ifstream *reader - file stream to read
@ -118,7 +118,7 @@ class MnistOp : public MappableLeafOp {
// @param std::ifstream *label_reader - label file stream
// @param int64_t read_num - number of image to read
// @return Status The status code returned
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
virtual Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
// Parse all mnist dataset files
// @return Status The status code returned
@ -126,7 +126,7 @@ class MnistOp : public MappableLeafOp {
// Read all files in the directory
// @return Status The status code returned
Status WalkAllFiles();
virtual Status WalkAllFiles();
// Called first when function is called
// @return Status The status code returned

View File

@ -0,0 +1,283 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/qmnist_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
const int32_t kQMnistLabelFileMagicNumber = 3074;
const int32_t kQMnistImageRows = 28;
const int32_t kQMnistImageCols = 28;
const int32_t kQMnistLabelLength = 8;
// Constructor: forwards the shared arguments to the MnistOp base class (note
// the different parameter order expected by the base constructor) and records
// the compat flag, which selects the MNIST-style scalar label vs the full
// QMNIST info label in LoadTensorRow.
QMnistOp::QMnistOp(const std::string &folder_path, const std::string &usage, bool compat,
                   std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t num_workers,
                   int32_t queue_size)
    : MnistOp(usage, num_workers, folder_path, queue_size, std::move(data_schema), std::move(sampler)),
      compat_(compat) {}
// A print method typically used for debugging.
// @param std::ostream &out - out stream.
// @param bool show_all - when false, emit only the common one-liner; when true, add QMnist-specific details.
void QMnistOp::Print(std::ostream &out, bool show_all) const {
  // The common portion always comes from the parent class.
  ParallelOp::Print(out, show_all);
  if (!show_all) {
    // One-liner mode: nothing QMnist-specific to add.
    out << "\n";
    return;
  }
  // Detailed mode: append the derived-class state.
  out << "\nNumber of rows: " << num_rows_ << "\n"
      << DatasetName(true) << " directory: " << folder_path_ << "\nUsage: " << usage_
      << "\nCompat: " << (compat_ ? "yes" : "no") << "\n\n";
}
// Load 1 TensorRow (image, label) using 1 MnistLabelPair or QMnistImageInfoPair.
// @param row_id_type row_id - id for this tensor row (index into the cached pair vectors).
// @param TensorRow *trow - output tensor row; receives {image, label} plus the source file paths.
// @return Status - The status code returned.
Status QMnistOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);
  std::shared_ptr<Tensor> image, label;
  if (compat_) {
    // Compat mode: label is the scalar class number, same layout as plain MNIST.
    MnistLabelPair qmnist_pair = image_label_pairs_[row_id];
    RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image));
    RETURN_IF_NOT_OK(Tensor::CreateScalar(qmnist_pair.second, &label));
  } else {
    // Full mode: label is the complete QMNIST info tensor (8 uint32 values).
    QMnistImageInfoPair qmnist_pair = image_info_pairs_[row_id];
    RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image));
    RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.second, &label));
  }
  (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
  // Record the originating file paths for provenance / error reporting.
  trow->setPath({image_path_[row_id], label_path_[row_id]});
  return Status::OK();
}
// Count the total number of QMNIST samples under `dir` for the given `usage`.
// Builds a throw-away QMnistOp, resolves the file list with WalkAllFiles(),
// then sums the per-file image counts, applying the same test10k/test50k
// subset caps that ReadImageAndLabel() applies when actually reading.
// @param const std::string &dir - path to the QMNIST directory.
// @param const std::string &usage - which subset to count.
// @param int64_t *count - output arg that will hold the actual dataset size.
// @return Status - The status code returned.
Status QMnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  // NOTE(review): num_samples = 0 presumably means "no cap" for SequentialSamplerRT — confirm.
  const int64_t num_samples = 0;
  const int64_t start_index = 0;
  auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();
  // compat does not affect the count result, so set it to true default.
  auto op =
    std::make_shared<QMnistOp>(dir, usage, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);
  // the logic of counting the number of samples
  RETURN_IF_NOT_OK(op->WalkAllFiles());
  for (size_t i = 0; i < op->image_names_.size(); ++i) {
    std::ifstream image_reader;
    image_reader.open(op->image_names_[i], std::ios::binary);
    std::ifstream label_reader;
    label_reader.open(op->label_names_[i], std::ios::binary);
    // CheckImage / CheckLabel validate the magic numbers and return the item counts.
    uint32_t num_images;
    RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images));
    uint32_t num_labels;
    RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels));
    CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
                                 "Invalid data, num of images is not equal to num of labels.");
    if (usage == "test10k") {
      // only use the first 10k samples and drop the last 50k samples
      num_images = 10000;
      num_labels = 10000;
    } else if (usage == "test50k") {
      // only use the last 50k samples and drop the first 10k samples
      num_images = 50000;
      num_labels = 50000;
    }
    *count = *count + num_images;
    // Close the readers
    image_reader.close();
    label_reader.close();
  }
  return Status::OK();
}
// Resolve the image/label file names for the configured usage and verify that
// each resolved path exists and is a regular file.
// @return Status - The status code returned.
Status QMnistOp::WalkAllFiles() {
  const std::string image_ext = "images-idx3-ubyte";
  const std::string label_ext = "labels-idx2-int";
  const std::string train_prefix = "qmnist-train";
  const std::string test_prefix = "qmnist-test";
  const std::string nist_prefix = "xnist";

  auto real_folder_path = FileUtils::GetRealPath(folder_path_.data());
  CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
  Path root_dir(real_folder_path.value());

  // Append the image/label file pair derived from one file-name prefix.
  auto append_pair = [&](const std::string &prefix) {
    image_names_.push_back((root_dir / Path(prefix + "-" + image_ext)).ToString());
    label_names_.push_back((root_dir / Path(prefix + "-" + label_ext)).ToString());
  };
  // test10k/test50k share the physical test files; the subset is applied later when reading.
  if (usage_ == "train") {
    append_pair(train_prefix);
  } else if (usage_ == "test" || usage_ == "test10k" || usage_ == "test50k") {
    append_pair(test_prefix);
  } else if (usage_ == "nist") {
    append_pair(nist_prefix);
  } else if (usage_ == "all") {
    append_pair(train_prefix);
    append_pair(test_prefix);
    append_pair(nist_prefix);
  }
  CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(),
                               "Invalid data, num of images does not equal to num of labels.");

  // Confirm every expected file is present and is not a directory.
  auto verify = [this](const std::vector<std::string> &names, const std::string &kind) -> Status {
    for (const auto &name : names) {
      Path file_path(name);
      CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(),
                                   "Failed to find " + DatasetName() + " " + kind + " file: " + file_path.ToString());
      MS_LOG(INFO) << DatasetName(true) << " operator found " << kind << " file at " << file_path.ToString() << ".";
    }
    return Status::OK();
  };
  RETURN_IF_NOT_OK(verify(image_names_, "image"));
  RETURN_IF_NOT_OK(verify(label_names_, "label"));
  return Status::OK();
}
// Read one image file / label file pair into the in-memory pair caches.
// Fills image_info_pairs_ (full 8-field label tensor) and image_label_pairs_
// (compat scalar label = field 0) in lockstep, so either mode can be served.
// @param std::ifstream *image_reader - image file stream.
// @param std::ifstream *label_reader - label file stream.
// @param size_t index - index into image_names_/label_names_ of the file pair being read.
// @return Status - The status code returned.
Status QMnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
  RETURN_UNEXPECTED_IF_NULL(image_reader);
  RETURN_UNEXPECTED_IF_NULL(label_reader);
  uint32_t num_images, num_labels;
  RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images));
  RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels));
  CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
                               "Invalid data, num_images is not equal to num_labels. Ensure data file is not damaged.");
  // The image size of the QMNIST dataset is fixed at [28,28]
  int64_t image_size = kQMnistImageRows * kQMnistImageCols;
  int64_t label_length = kQMnistLabelLength;
  if (usage_ == "test10k") {
    // only use the first 10k samples and drop the last 50k samples
    num_images = 10000;
    num_labels = 10000;
  } else if (usage_ == "test50k") {
    num_images = 50000;
    num_labels = 50000;
    // skip the first 10k samples for ifstream reader
    (void)image_reader->ignore(image_size * 10000);
    (void)label_reader->ignore(label_length * 10000 * 4);
  }
  // Bulk-read the remaining payload into contiguous buffers.
  auto images_buf = std::make_unique<char[]>(image_size * num_images);
  auto labels_buf = std::make_unique<uint32_t[]>(label_length * num_labels);
  if (images_buf == nullptr || labels_buf == nullptr) {
    std::string err_msg = "[Internal ERROR] Failed to allocate memory for " + DatasetName() + " buffer.";
    MS_LOG(ERROR) << err_msg.c_str();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  (void)image_reader->read(images_buf.get(), image_size * num_images);
  if (image_reader->fail()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " image: " + image_names_[index] +
                             ", size:" + std::to_string(image_size * num_images) +
                             ". Ensure data file is not damaged.");
  }
  // uint32_t use 4 bytes in memory
  (void)label_reader->read(reinterpret_cast<char *>(labels_buf.get()), label_length * num_labels * 4);
  if (label_reader->fail()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " label:" + label_names_[index] +
                             ", size: " + std::to_string(label_length * num_labels) +
                             ". Ensure data file is not damaged.");
  }
  TensorShape image_tensor_shape = TensorShape({kQMnistImageRows, kQMnistImageCols, 1});
  TensorShape label_tensor_shape = TensorShape({kQMnistLabelLength});
  for (int64_t data_index = 0; data_index != num_images; data_index++) {
    auto image = &images_buf[data_index * image_size];
    // Binarize pixels in place: any nonzero byte becomes 255.
    for (int64_t image_index = 0; image_index < image_size; image_index++) {
      image[image_index] = (image[image_index] == 0) ? 0 : 255;
    }
    std::shared_ptr<Tensor> image_tensor;
    RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(),
                                              reinterpret_cast<unsigned char *>(image), &image_tensor));
    auto label = &labels_buf[data_index * label_length];
    // Label words are stored big-endian on disk; convert to host order in place.
    for (int64_t label_index = 0; label_index < label_length; label_index++) {
      label[label_index] = SwapEndian(label[label_index]);
    }
    std::shared_ptr<Tensor> label_tensor;
    RETURN_IF_NOT_OK(Tensor::CreateFromMemory(label_tensor_shape, data_schema_->Column(1).Type(),
                                              reinterpret_cast<unsigned char *>(label), &label_tensor));
    // label[0] is used as the compat-mode class number; presumably the class id
    // is the first field of the QMNIST label record — TODO confirm against the format spec.
    image_info_pairs_.emplace_back(std::make_pair(image_tensor, label_tensor));
    image_label_pairs_.emplace_back(std::make_pair(image_tensor, label[0]));
    image_path_.push_back(image_names_[index]);
    label_path_.push_back(label_names_[index]);
  }
  return Status::OK();
}
// Validate a QMNIST label file header and report the number of label records.
// Header layout: magic number (4 bytes) + item count (4 bytes) + record length (4 bytes).
// @param const std::string &file_name - label file name.
// @param std::ifstream *label_reader - open label file stream; positioned just past the header on success.
// @param uint32_t *num_labels - output, number of label records in the file.
// @return Status - The status code returned.
Status QMnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
  RETURN_UNEXPECTED_IF_NULL(label_reader);
  RETURN_UNEXPECTED_IF_NULL(num_labels);
  CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(),
                               "Invalid file, failed to open " + DatasetName() + " label file: " + file_name);
  int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
  (void)label_reader->seekg(0, std::ios::beg);
  // Named constants replace the former magic numbers 12 and 4.
  constexpr int64_t kWordSize = static_cast<int64_t>(sizeof(uint32_t));
  constexpr int64_t kHeaderSize = 3 * kWordSize;  // magic + count + length
  CHECK_FAIL_RETURN_UNEXPECTED(label_len >= kHeaderSize,
                               "Invalid file, " + DatasetName() + " file is corrupted: " + file_name);
  uint32_t magic_number;
  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
  CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kQMnistLabelFileMagicNumber,
                               "Invalid file, this is not the " + DatasetName() + " label file: " + file_name);
  uint32_t num_items;
  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
  uint32_t length;
  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &length));
  CHECK_FAIL_RETURN_UNEXPECTED(length == kQMnistLabelLength,
                               "Invalid data, length of labels is not equal to " +
                                 std::to_string(kQMnistLabelLength) + ".");
  // Compare in int64 so a corrupted (huge) item count cannot wrap the unsigned product.
  CHECK_FAIL_RETURN_UNEXPECTED(
    (label_len - kHeaderSize) == static_cast<int64_t>(num_items) * kQMnistLabelLength * kWordSize,
    "Invalid data, number of labels is wrong.");
  *num_labels = num_items;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
using QMnistImageInfoPair = std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>;
class QMnistOp : public MnistOp {
public:
// Constructor.
// @param const std::string &folder_path - dir directory of QMNIST data file.
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or
// 'all'.
// @param bool compat - Compatibility with Mnist.
// @param std::unique_ptr<DataSchema> data_schema - the schema of the QMNIST dataset.
// @param td::unique_ptr<Sampler> sampler - sampler tells QMnistOp what to read.
// @param int32_t num_workers - number of workers reading images in parallel.
// @param int32_t queue_size - connector queue size.
QMnistOp(const std::string &folder_path, const std::string &usage, bool compat,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t num_workers,
int32_t queue_size);
// Destructor.
~QMnistOp() = default;
// Op name getter.
// @return std::string - Name of the current Op.
std::string Name() const override { return "QMnistOp"; }
// DatasetName name getter
// \return std::string - DatasetName of the current Op
std::string DatasetName(bool upper = false) const { return upper ? "QMnist" : "qmnist"; }
// A print method typically used for debugging.
// @param std::ostream &out - out stream.
// @param bool show_all - whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the QMNIST dataset.
// @param const std::string &dir - path to the QMNIST directory.
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or
// 'all'.
// @param int64_t *count - output arg that will hold the actual dataset size.
// @return Status -The status code returned.
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
private:
// Load a tensor row according to a pair.
// @param row_id_type row_id - id for this tensor row.
// @param TensorRow row - image & label read into this tensor row.
// @return Status - The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
// Get needed files in the folder_path_.
// @return Status - The status code returned.
Status WalkAllFiles() override;
// Read images and labels from the file stream.
// @param std::ifstream *image_reader - image file stream.
// @param std::ifstream *label_reader - label file stream.
// @param size_t index - the index of file that is reading.
// @return Status The status code returned.
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) override;
// Check label stream.
// @param const std::string &file_name - label file name.
// @param std::ifstream *label_reader - label file stream.
// @param uint32_t num_labels - returns the number of labels.
// @return Status The status code returned.
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) override;
const bool compat_; // compatible with mnist
std::vector<QMnistImageInfoPair> image_info_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_

View File

@ -89,6 +89,7 @@ constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kQMnistNode[] = "QMnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kSBUNode[] = "SBUDataset";
constexpr char kTextFileNode[] = "TextFileDataset";

View File

@ -17,6 +17,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
manifest_node.cc
minddata_node.cc
mnist_node.cc
qmnist_node.cc
random_node.cc
sbu_node.cc
text_file_node.cc

View File

@ -0,0 +1,150 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/qmnist_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Construct the IR node; the cache is handed to the MappableSourceNode base,
// everything else is stored for validation and Build().
QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat,
                       std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
    : MappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      usage_(usage),
      compat_(compat),
      sampler_(std::move(sampler)) {}
std::shared_ptr<DatasetNode> QMnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<QMnistNode>(dataset_dir_, usage_, compat_, sampler, cache_);
return node;
}
// Emit a single-line human-readable description of this node.
void QMnistNode::Print(std::ostream &out) const {
  std::string summary = Name() + "(dataset dir: " + dataset_dir_ + ", usage: " + usage_;
  summary += ", compat: ";
  summary += compat_ ? "true" : "false";
  summary += ", cache: ";
  summary += (cache_ != nullptr) ? "true" : "false";
  summary += ")";
  out << summary;
}
// Validate user-supplied parameters: base-node checks first, then the dataset
// directory, the sampler, and finally that usage is one of the supported subsets.
Status QMnistNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistNode", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistNode", sampler_));
  RETURN_IF_NOT_OK(ValidateStringValue("QMnistNode", usage_, {"train", "test", "test10k", "test50k", "nist", "all"}));
  return Status::OK();
}
// Build the runtime QMnistOp for this IR node. The schema's label column shape
// depends on compat_: scalar class number vs rank-1 full QMNIST info tensor.
// @param node_ops - output vector receiving the constructed op.
// @return Status - The status code returned.
Status QMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  if (compat_) {
    // Compat mode: the label column is a scalar class number, as in MNIST.
    TensorShape scalar = TensorShape::CreateScalar();
    RETURN_IF_NOT_OK(
      schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  } else {
    // Full mode: the label column is a rank-1 tensor holding the complete QMNIST info.
    RETURN_IF_NOT_OK(
      schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  }
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  auto op = std::make_shared<QMnistOp>(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt),
                                       num_workers_, connector_que_size_);
  // Propagate repeat bookkeeping so the runtime op reports epoch boundaries correctly.
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node.
// @param[out] shard_id - receives the shard id reported by the sampler.
// @return Status - The status code returned.
Status QMnistNode::GetShardId(int32_t *shard_id) {
  // Guard the output pointer before delegating to the sampler (ValidateParams
  // has already guaranteed sampler_ itself is non-null).
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  *shard_id = sampler_->ShardId();
  return Status::OK();
}
// Get Dataset size. Counts the raw rows on disk, then lets the sampler shrink
// that number; falls back to a dry run of the pipeline when the sampler cannot
// compute its sample count analytically. The result is memoized in dataset_size_.
Status QMnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                  int64_t *dataset_size) {
  // Return the memoized value when available.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(QMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  sample_size = sampler_rt->CalculateNumSamples(num_rows);
  if (sample_size == -1) {
    // -1 means the sampler cannot derive the count; determine it by dry-running the tree.
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
  }
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
// Serialize this node's constructor arguments to JSON (inverse of from_json).
Status QMnistNode::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["dataset_dir"] = dataset_dir_;
  args["usage"] = usage_;
  args["compat"] = compat_;
  args["num_parallel_workers"] = num_workers_;
  // Every mappable source node carries a sampler; serialize it as a nested object.
  nlohmann::json sampler_args;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
  args["sampler"] = sampler_args;
  // The cache entry is optional and only emitted when a cache is attached.
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a QMnistNode from JSON (inverse of to_json).
// @param[in] json_obj The JSON object to be deserialized.
// @param[out] ds Deserialized dataset node.
// @return Status - The status code returned.
Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Verify every required key is present before reading any of them.
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("compat") != json_obj.end(), "Failed to find compat");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  bool compat = json_obj["compat"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  // The cache key is optional; from_json leaves cache null when it is absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<QMnistNode>(dataset_dir, usage, compat, sampler, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node describing a QMNIST source dataset; Build() produces the runtime QMnistOp.
class QMnistNode : public MappableSourceNode {
 public:
  /// \brief Constructor.
  QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, std::shared_ptr<SamplerObj> sampler,
             std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor.
  ~QMnistNode() = default;

  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kQMnistNode; }

  /// \brief Print the description.
  /// \param out - The output stream to write output to.
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief a base class override function to create the required runtime dataset op objects for this class.
  /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build successfully.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;

  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;

  /// \brief Get the shard id of node.
  /// \param[in] shard_id The shard id.
  /// \return Status Status::OK() if get shard id successfully.
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }

  /// \brief Getter functions.
  const std::string &Usage() const { return usage_; }

  /// \brief Getter functions.
  /// \note Returns plain `bool`: a top-level const on a by-value return type is
  ///       meaningless and triggers -Wignored-qualifiers.
  bool Compat() const { return compat_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

#ifndef ENABLE_ANDROID
  /// \brief Function to read dataset in json
  /// \param[in] json_obj The JSON object to be deserialized
  /// \param[out] ds Deserialized dataset
  /// \return Status The status code returned
  static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif

  /// \brief Sampler getter.
  /// \return SamplerObj of the current node.
  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

  /// \brief Sampler setter.
  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

 private:
  std::string dataset_dir_;
  std::string usage_;
  bool compat_;
  std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_

View File

@ -2273,6 +2273,90 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class QMnistDataset
/// \brief A source dataset that reads and parses QMNIST dataset.
class QMnistDataset : public Dataset {
 public:
  /// \brief Constructor of QMnistDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
  /// \param[in] compat Whether the label for each example is class number (compat=true)
  ///     or the full QMNIST information (compat=false).
  /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  ///     given, a `RandomSampler` will be used to randomly iterate the entire dataset.
  /// \param[in] cache Tensor cache to use.
  explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                         const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of QMnistDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
  /// \param[in] compat Whether the label for each example is class number (compat=true)
  ///     or the full QMNIST information (compat=false).
  /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  ///     The pointer is borrowed, not owned.
  /// \param[in] cache Tensor cache to use.
  explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                         const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of QMnistDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
  /// \param[in] compat Whether the label for each example is class number (compat=true)
  ///     or the full QMNIST information (compat=false).
  /// \param[in] sampler Sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  explicit QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                         const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// Destructor of QMnistDataset.
  ~QMnistDataset() = default;
};
/// \brief Function to create a QMnistDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all" (default = "all").
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information
///     (default = true).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the QMnistDataset.
inline std::shared_ptr<QMnistDataset> QMnist(
    const std::string &dataset_dir, const std::string &usage = "all", bool compat = true,
    const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
    const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Convert the std::string arguments to the char-vector form used across the ABI boundary.
  auto dir_chars = StringToChar(dataset_dir);
  auto usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}
/// \brief Function to create a QMnistDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the QMnistDataset.
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
                                             const Sampler *sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // The raw sampler pointer is borrowed; forward everything to the char-vector constructor.
  auto dir_chars = StringToChar(dataset_dir);
  auto usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}
/// \brief Function to create a QMnistDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all".
/// \param[in] compat Whether the label for each example is class number or the full QMNIST information.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the QMnistDataset.
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
                                             const std::reference_wrapper<Sampler> sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Forward to the char-vector constructor, which unwraps the reference_wrapper itself.
  auto dir_chars = StringToChar(dataset_dir);
  auto usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}
/// \brief Function to create a ConcatDataset.
/// \note Reload "+" operator to concat two datasets.
/// \param[in] datasets1 Shared pointer to the first dataset to be concatenated.
@ -2565,15 +2649,14 @@ class USPSDataset : public Dataset {
public:
/// \brief Constructor of USPSDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0 means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
/// \param[in] usage Usage of USPS, can be "train", "test" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).

View File

@ -44,6 +44,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class ManifestDataset;
friend class MindDataDataset;
friend class MnistDataset;
friend class QMnistDataset;
friend class RandomDataDataset;
friend class SBUDataset;
friend class TextFileDataset;

View File

@ -66,7 +66,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset
check_sbu_dataset, check_qmnist_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -3441,6 +3441,131 @@ class MnistDataset(MappableDataset):
return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
class QMnistDataset(MappableDataset):
    """
    A source dataset for reading and parsing the QMNIST dataset.

    The generated dataset has two columns :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar when `compat` is True else a tensor, both of the uint32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be `train`, `test`, `test10k`, `test50k`, `nist`
            or `all` (default=None, will read all samples).
        compat (bool, optional): Whether the label for each example is the class number (compat=True) or the
            full QMNIST information (compat=False) (default=True).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, will read all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, will use value set in the config).
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
            (default=None, expected order behavior shown in the table).
        sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If dataset_dir does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Note:
        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> qmnist_dataset_dir = "/path/to/qmnist_dataset_directory"
        >>>
        >>> # Read 3 samples from QMNIST train dataset
        >>> dataset = ds.QMnistDataset(dataset_dir=qmnist_dataset_dir, num_samples=3)
        >>>
        >>> # Note: In QMNIST dataset, each dictionary has keys "image" and "label"

    About QMNIST dataset:

    The QMNIST dataset was generated from the original data found in the NIST Special Database 19 with the goal to
    match the MNIST preprocessing as closely as possible.
    Through an iterative process, researchers tried to generate an additional 50k images of MNIST-like data.
    They started with a reconstruction process given in the paper and used the Hungarian algorithm to find the best
    matches between the original MNIST samples and their reconstructed samples.

    Here is the original QMNIST dataset structure.
    You can unzip the dataset files into this directory structure and read by MindSpore's API.

    .. code-block::

        .
        qmnist_dataset_dir
            qmnist-train-images-idx3-ubyte
            qmnist-train-labels-idx2-int
            qmnist-test-images-idx3-ubyte
            qmnist-test-labels-idx2-int
            xnist-images-idx3-ubyte
            xnist-labels-idx2-int

    Citation:

    .. code-block::

        @incollection{qmnist-2019,
        title = "Cold Case: The Lost MNIST Digits",
        author = "Chhavi Yadav and L\'{e}on Bottou",
        booktitle = {Advances in Neural Information Processing Systems 32},
        year = {2019},
        publisher = {Curran Associates, Inc.},
        }
    """

    @check_qmnist_dataset
    def __init__(self, dataset_dir, usage=None, compat=True, num_samples=None, num_parallel_workers=None,
                 shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir
        # A usage of None means "all": read every available split.
        self.usage = replace_none(usage, "all")
        self.compat = compat

    def parse(self, children=None):
        # Build the C++ IR node; `children` is unused because this is a leaf (source) node.
        return cde.QMnistNode(self.dataset_dir, self.usage, self.compat, self.sampler)
class MindDataset(MappableDataset):
"""
A source dataset for reading and parsing MindRecord dataset.

View File

@ -92,6 +92,36 @@ def check_mnist_cifar_dataset(method):
return new_method
def check_qmnist_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(QMnistDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # The dataset directory must exist before anything else is checked.
        check_dir(param_dict.get('dataset_dir'))

        # `usage` is optional; when supplied it must name a known split.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage")

        # Type-check the numeric and boolean keyword arguments.
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'compat'], param_dict, bool)

        # sampler/shuffle/sharding are mutually constrained; cache is validated last.
        check_sampler_shuffle_shard_options(param_dict)
        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
def check_manifestdataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

View File

@ -28,6 +28,7 @@ SET(DE_UT_SRCS
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_qmnist_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_save.cc
c_api_dataset_sbu_test.cc

View File

@ -0,0 +1,343 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
// Test fixture for the QMNIST dataset pipeline tests; inherits the shared
// dataset-op testing scaffolding (e.g. datasets_root_path_).
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Verify that a QMNIST "train" dataset can be created and iterated end to end.
TEST_F(MindDataTestPipeline, TestQMnistTrainDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDataset.";

  // Create a QMNIST Train Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 5 samples, so exactly 5 rows must be produced.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify that a QMNIST "test" dataset can be created and iterated end to end.
TEST_F(MindDataTestPipeline, TestQMnistTestDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTestDataset.";

  // Create a QMNIST Test Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "test", true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 5 samples, so exactly 5 rows must be produced.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify that a QMNIST "nist" dataset can be created and iterated end to end.
TEST_F(MindDataTestPipeline, TestQMnistNistDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistNistDataset.";

  // Create a QMNIST Nist Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "nist", true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 5 samples, so exactly 5 rows must be produced.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify that usage "all" reads across every split (train + test + nist).
TEST_F(MindDataTestPipeline, TestQMnistAllDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistAllDataset.";

  // Create a QMNIST All Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "all", true, std::make_shared<RandomSampler>(false, 20));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 20 samples across all splits.
  EXPECT_EQ(i, 20);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify that compat=false (full QMNIST label record instead of a scalar class
/// number) still produces an iterable dataset with both output columns.
TEST_F(MindDataTestPipeline, TestQMnistCompatDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistCompatDataset.";

  // Create a QMNIST All Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "all", false, std::make_shared<RandomSampler>(false, 20));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    auto label = row["label"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    MS_LOG(INFO) << "Tensor label shape: " << label.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sampler requested 20 samples.
  EXPECT_EQ(i, 20);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify QMNIST combined with Repeat, Project and Concat pipeline ops.
TEST_F(MindDataTestPipeline, TestQMnistDatasetWithPipeline) {
  // Fixed: log message previously said "TestQMnistTrainDatasetWithPipeline",
  // which did not match the registered test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDatasetWithPipeline.";

  // Create two QMNIST Train Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds1 = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
  std::shared_ptr<Dataset> ds2 = QMnist(folder_path, "train", true, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create two Repeat operation on ds
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operation on ds
  std::vector<std::string> column_project = {"image", "label"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Create a Concat operation on the ds
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // 5 samples from each branch, concatenated -> 10 rows.
  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
}
/// Verify GetDatasetSize for the QMNIST train split of the test fixture.
TEST_F(MindDataTestPipeline, TestGetQMnistDatasetSize) {
  // Fixed: log message previously said "TestGetQMnistTrainDatasetSize",
  // which did not match the registered test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetQMnistDatasetSize.";

  // Create a QMNIST Train Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true);
  EXPECT_NE(ds, nullptr);

  // The checked-in fixture contains 10 train samples.
  EXPECT_EQ(ds->GetDatasetSize(), 10);
}
/// Verify the getter functions (types, shapes, column names, sizes) and that
/// repeated calls keep returning consistent values.
TEST_F(MindDataTestPipeline, TestQMnistDatasetGetters) {
  // Fixed: log message previously said "TestQMnistTrainDatasetGetters",
  // which did not match the registered test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDatasetGetters.";

  // Create a QMNIST Train Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetDatasetSize(), 10);

  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  int64_t num_classes = ds->GetNumClasses();

  // With compat=true the output is a uint8 28x28x1 image and a scalar uint32 label.
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "uint32");
  EXPECT_EQ(shapes.size(), 2);
  EXPECT_EQ(shapes[0].ToString(), "<28,28,1>");
  EXPECT_EQ(shapes[1].ToString(), "<>");
  // QMNIST does not report a class count.
  EXPECT_EQ(num_classes, -1);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // Getters are queried repeatedly to confirm the cached results stay stable.
  EXPECT_EQ(ds->GetDatasetSize(), 10);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), -1);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 10);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), -1);
  EXPECT_EQ(ds->GetDatasetSize(), 10);
}
/// Negative case: an empty dataset_dir must fail at iterator-creation time.
TEST_F(MindDataTestPipeline, TestQMnistDataFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataFail.";

  // Create a QMNIST Dataset
  std::shared_ptr<Dataset> ds = QMnist("", "train", true, std::make_shared<RandomSampler>(false, 5));
  // Validation is deferred: the IR node is still created with bad input.
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid QMNIST input
  EXPECT_EQ(iter, nullptr);
}
/// Negative case: an unsupported usage string must fail at iterator-creation time.
TEST_F(MindDataTestPipeline, TestQMnistDataWithInvalidUsageFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithInvalidUsageFail.";

  // Create a QMNIST Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "validation", true);
  // Validation is deferred: the IR node is still created with bad input.
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid QMNIST input, validation is not a valid usage
  EXPECT_EQ(iter, nullptr);
}
/// Negative case: a null sampler must fail at iterator-creation time.
TEST_F(MindDataTestPipeline, TestQMnistDataWithNullSamplerFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithNullSamplerFail.";

  // Create a QMNIST Dataset
  std::string folder_path = datasets_root_path_ + "/testQMnistData/";
  std::shared_ptr<Dataset> ds = QMnist(folder_path, "train", true, nullptr);
  // Validation is deferred: the IR node is still created with bad input.
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid QMNIST input, sampler cannot be nullptr
  EXPECT_EQ(iter, nullptr);
}

View File

@ -0,0 +1,343 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test QMnistDataset operator
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
# Root of the checked-in QMNIST test fixture data used by every test below.
DATA_DIR = "../data/dataset/testQMnistData"
def load_qmnist(path, usage, compat=True):
    """
    Load QMNIST images and labels from the raw idx files, mirroring the pipeline reader.

    Args:
        path (str): directory containing the qmnist/xnist idx files.
        usage (str): one of "train", "test", "nist" or "all".
        compat (bool): if True return only the class number (column 0 of the full
            QMNIST label record); otherwise return the full (N, 8) record.

    Returns:
        tuple: (images, labels) where images is a uint8 array of shape (N, 28, 28, 1)
        and labels is (N,) when compat is True, else (N, 8).
    """
    image_ext = "images-idx3-ubyte"
    label_ext = "labels-idx2-int"
    # Map each usage to the file prefixes it covers; "all" reads every split.
    prefix_map = {
        "train": ["qmnist-train"],
        "test": ["qmnist-test"],
        "nist": ["xnist"],
        "all": ["qmnist-train", "qmnist-test", "xnist"],
    }
    assert usage in prefix_map
    image_path = [os.path.realpath(os.path.join(path, prefix + "-" + image_ext)) for prefix in prefix_map[usage]]
    label_path = [os.path.realpath(os.path.join(path, prefix + "-" + label_ext)) for prefix in prefix_map[usage]]
    assert len(image_path) == len(label_path)

    images = []
    labels = []
    for image_file_path, label_file_path in zip(image_path, label_path):
        with open(image_file_path, 'rb') as image_file:
            image_file.read(16)  # skip the 16-byte idx3 header
            image = np.fromfile(image_file, dtype=np.uint8)
            image = image.reshape(-1, 28, 28, 1)
            image[image > 0] = 255  # Perform binarization to maintain consistency with our API
            images.append(image)
        with open(label_file_path, 'rb') as label_file:
            label_file.read(12)  # skip the 12-byte idx2 header
            label = np.fromfile(label_file, dtype='>u4')  # labels are stored big-endian uint32
            label = label.reshape(-1, 8)
            labels.append(label)
    images = np.concatenate(images, 0)
    labels = np.concatenate(labels, 0)
    if compat:
        return images, labels[:, 0]
    return images, labels
def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    total = len(images)
    # Lay all samples out on a single row of subplots, titled with their labels.
    for index, image in enumerate(images):
        plt.subplot(1, total, index + 1)
        plt.imshow(image.squeeze(), cmap=plt.cm.gray)
        plt.title(labels[index])
    plt.show()
def test_qmnist_content_check():
    """
    Validate QMnistDataset image readings
    """
    logger.info("Test QMnistDataset Op with content check")
    # Compare pipeline output against a direct read of the raw files, with both
    # compat=True (scalar class label) and compat=False (full 8-field record).
    # Fixed: the original duplicated this loop for each compat value and
    # accumulated image_list/label_list that were never used (dead locals).
    for compat in (True, False):
        for usage in ["train", "test", "nist", "all"]:
            data1 = ds.QMnistDataset(DATA_DIR, usage, compat, num_samples=10, shuffle=False)
            images, labels = load_qmnist(DATA_DIR, usage, compat)
            num_iter = 0
            # in this example, each dictionary has keys "image" and "label"
            for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
                np.testing.assert_array_equal(data["image"], images[i])
                np.testing.assert_array_equal(data["label"], labels[i])
                num_iter += 1
            assert num_iter == 10
def test_qmnist_basic():
    """
    Validate QMnistDataset
    """
    logger.info("Test QMnistDataset Op")

    # case 1: test loading whole dataset
    data1 = ds.QMnistDataset(DATA_DIR, "train", True)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    # the checked-in train fixture holds 10 samples
    assert num_iter1 == 10

    # case 2: test num_samples
    data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 5

    # case 3: test repeat
    data3 = ds.QMnistDataset(DATA_DIR, "train", True)
    data3 = data3.repeat(5)
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    # 10 samples repeated 5 times
    assert num_iter3 == 50

    # case 4: test batch with drop_remainder=False
    data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert data4.get_dataset_size() == 10
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder is default to be False
    # 10 samples / batches of 7 -> 2 batches (second one is partial)
    assert data4.get_dataset_size() == 2
    assert data4.get_batch_size() == 7
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 2

    # case 5: test batch with drop_remainder=True
    data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert data5.get_dataset_size() == 10
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=3, drop_remainder=True)  # the rest of incomplete batch will be dropped
    # 10 samples / batches of 3 -> 3 full batches, remainder dropped
    assert data5.get_dataset_size() == 3
    assert data5.get_batch_size() == 3
    num_iter5 = 0
    for _ in data5.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 3

    # case 6: test get_col_names
    dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert dataset.get_col_names() == ["image", "label"]
def test_qmnist_pk_sampler():
    """
    Feature: QMnistDataset
    Description: Read the "nist" split driven by a PKSampler requesting 10 samples per class
    Expectation: Exactly 10 rows come back and every label equals 0
    """
    logger.info("Test QMnistDataset Op with PKSampler")
    expected_labels = [0] * 10
    data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(10))
    seen_labels = []
    row_count = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        seen_labels.append(row["label"])
        row_count += 1
    np.testing.assert_array_equal(expected_labels, seen_labels)
    assert row_count == 10
def test_qmnist_sequential_sampler():
    """
    Feature: QMnistDataset
    Description: Compare an explicit SequentialSampler pipeline against shuffle=False + num_samples
    Expectation: Both pipelines yield identical labels in identical order
    """
    logger.info("Test QMnistDataset Op with SequentialSampler")
    sample_count = 10
    seq_sampler = ds.SequentialSampler(num_samples=sample_count)
    sampled = ds.QMnistDataset(DATA_DIR, "train", True, sampler=seq_sampler)
    ordered = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=sample_count)
    labels_sampled, labels_ordered = [], []
    row_count = 0
    for row1, row2 in zip(sampled.create_dict_iterator(num_epochs=1), ordered.create_dict_iterator(num_epochs=1)):
        labels_sampled.append(row1["label"].asnumpy())
        labels_ordered.append(row2["label"].asnumpy())
        row_count += 1
    np.testing.assert_array_equal(labels_sampled, labels_ordered)
    assert row_count == sample_count
def test_qmnist_exception():
    """
    Feature: QMnistDataset
    Description: Pass invalid or conflicting constructor arguments and failing map operations
    Expectation: Each case raises the expected exception type with the expected message
    """
    # Fixed: the log line previously said "MnistDataset"; this test covers QMnistDataset.
    logger.info("Test error cases for QMnistDataset")
    # sampler conflicts with shuffle / sharding arguments
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3))
    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0)
    # num_shards and shard_id must be given together
    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10)
    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0)
    # shard_id must lie in [0, num_shards)
    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5)
    # num_parallel_workers must be a positive value within the supported range
    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2)
    # shard_id must be an integer, not a string
    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0")

    def exception_func(item):
        # Deliberately fail inside map() to exercise pipeline error propagation.
        raise Exception("Error occur!")

    error_msg_8 = "The corresponding data files"
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
def test_qmnist_visualize(plot=False):
    """
    Feature: QMnistDataset
    Description: Iterate 10 unshuffled rows and validate each image's shape and dtype
    Expectation: Every image is a 28x28x1 uint8 ndarray with a uint32 label; optionally plot them
    """
    logger.info("Test QMnistDataset visualization")
    data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False)
    images, labels = [], []
    row_count = 0
    for row in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image, label = row["image"], row["label"]
        images.append(image)
        labels.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        row_count += 1
    assert row_count == 10
    if plot:
        visualize_dataset(images, labels)
def test_qmnist_usage():
    """
    Feature: QMnistDataset
    Description: Exercise the usage flag with every valid value plus invalid ones
    Expectation: Valid usages yield the expected row counts; invalid ones surface error text
    """
    logger.info("Test QMnistDataset usage flag")

    def count_rows(usage, path=None):
        # Returns the row count for a readable usage, or the error text when
        # construction/iteration fails with a ValueError/TypeError/RuntimeError.
        path = DATA_DIR if path is None else path
        try:
            data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False)
            return sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)

    assert count_rows("train") == 10
    assert count_rows("test") == 10
    assert count_rows("nist") == 10
    assert count_rows("all") == 30
    assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in \
        count_rows("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in count_rows(["list"])
if __name__ == '__main__':
    # Run the full QMnist test suite when executed directly as a script,
    # preserving the original execution order.
    for runner in (test_qmnist_content_check,
                   test_qmnist_basic,
                   test_qmnist_pk_sampler,
                   test_qmnist_sequential_sampler,
                   test_qmnist_exception):
        runner()
    test_qmnist_visualize(plot=True)
    test_qmnist_usage()