diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index a0609594db0..49d9fee5808 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -102,6 +102,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h" #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h" @@ -1220,6 +1221,28 @@ MnistDataset::MnistDataset(const std::vector &dataset_dir, const std::vect } #ifndef ENABLE_ANDROID +QMnistDataset::QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const std::shared_ptr &sampler, const std::shared_ptr &cache) { + auto sampler_obj = sampler ? sampler->Parse() : nullptr; + auto ds = std::make_shared(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache); + ir_node_ = std::static_pointer_cast(ds); +} + +QMnistDataset::QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const Sampler *sampler, const std::shared_ptr &cache) { + auto sampler_obj = sampler ? 
sampler->Parse() : nullptr; + auto ds = std::make_shared(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache); + ir_node_ = std::static_pointer_cast(ds); +} + +QMnistDataset::QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const std::reference_wrapper sampler, + const std::shared_ptr &cache) { + auto sampler_obj = sampler.get().Parse(); + auto ds = std::make_shared(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache); + ir_node_ = std::static_pointer_cast(ds); +} + TextFileDataset::TextFileDataset(const std::vector> &dataset_files, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache) { diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc index d3cd755ac41..6ddefe7ef77 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc @@ -44,6 +44,7 @@ #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h" #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h" @@ -248,6 +249,17 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { })); })); +PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) { + (void)py::class_>(*m, "QMnistNode", + "to create an QMnistNode") + .def(py::init([](std::string dataset_dir, std::string usage, bool compat, py::handle sampler) { + auto qmnist = + 
std::make_shared(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr); + THROW_IF_ERROR(qmnist->ValidateParams()); + return qmnist; + })); + })); + PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "RandomNode", "to create a RandomNode") diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt index 5bfc24eaf2f..07ad602463c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt @@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES cityscapes_op.cc div2k_op.cc flickr_op.cc + qmnist_op.cc ) set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index d038274ab44..dcf7ec82404 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -81,7 +81,7 @@ class MnistOp : public MappableLeafOp { // \return DatasetName of the current Op virtual std::string DatasetName(bool upper = false) const { return upper ? "Mnist" : "mnist"; } - private: + protected: // Load a tensor row according to a pair // @param row_id_type row_id - id for this tensor row // @param ImageLabelPair pair - @@ -94,14 +94,14 @@ class MnistOp : public MappableLeafOp { // @param std::ifstream *image_reader - image file stream // @param uint32_t num_images - returns the number of images // @return Status The status code returned - Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); + virtual Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); // Check label stream. 
// @param const std::string &file_name - label file name // @param std::ifstream *label_reader - label file stream // @param uint32_t num_labels - returns the number of labels // @return Status The status code returned - Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); + virtual Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); // Read 4 bytes of data from a file stream. // @param std::ifstream *reader - file stream to read @@ -118,7 +118,7 @@ class MnistOp : public MappableLeafOp { // @param std::ifstream *label_reader - label file stream // @param int64_t read_num - number of image to read // @return Status The status code returned - Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); + virtual Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); // Parse all mnist dataset files // @return Status The status code returned @@ -126,7 +126,7 @@ class MnistOp : public MappableLeafOp { // Read all files in the directory // @return Status The status code returned - Status WalkAllFiles(); + virtual Status WalkAllFiles(); // Called first when function is called // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc new file mode 100644 index 00000000000..7ada042496a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc @@ -0,0 +1,283 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/qmnist_op.h" + +#include +#include +#include +#include +#include + +#include "debug/common.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/file_utils.h" +#include "utils/ms_utils.h" + +namespace mindspore { +namespace dataset { +const int32_t kQMnistLabelFileMagicNumber = 3074; +const int32_t kQMnistImageRows = 28; +const int32_t kQMnistImageCols = 28; +const int32_t kQMnistLabelLength = 8; + +QMnistOp::QMnistOp(const std::string &folder_path, const std::string &usage, bool compat, + std::unique_ptr data_schema, std::shared_ptr sampler, int32_t num_workers, + int32_t queue_size) + : MnistOp(usage, num_workers, folder_path, queue_size, std::move(data_schema), std::move(sampler)), + compat_(compat) {} + +void QMnistOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows: " << num_rows_ << "\n" + << DatasetName(true) << " directory: " << folder_path_ << 
"\nUsage: " << usage_ + << "\nCompat: " << (compat_ ? "yes" : "no") << "\n\n"; + } +} + +// Load 1 TensorRow (image, label) using 1 MnistLabelPair or QMnistImageInfoPair. +Status QMnistOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { + RETURN_UNEXPECTED_IF_NULL(trow); + std::shared_ptr image, label; + if (compat_) { + MnistLabelPair qmnist_pair = image_label_pairs_[row_id]; + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(qmnist_pair.second, &label)); + } else { + QMnistImageInfoPair qmnist_pair = image_info_pairs_[row_id]; + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.first, &image)); + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(qmnist_pair.second, &label)); + } + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + trow->setPath({image_path_[row_id], label_path_[row_id]}); + return Status::OK(); +} + +Status QMnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { + RETURN_UNEXPECTED_IF_NULL(count); + *count = 0; + + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + const int64_t num_samples = 0; + const int64_t start_index = 0; + auto sampler = std::make_shared(start_index, num_samples); + + std::shared_ptr cfg = GlobalContext::config_manager(); + int32_t num_workers = cfg->num_parallel_workers(); + int32_t op_connector_size = cfg->op_connector_size(); + + // compat does not affect the count result, so set it to true default. 
+ auto op = + std::make_shared(dir, usage, true, std::move(schema), std::move(sampler), num_workers, op_connector_size); + + // the logic of counting the number of samples + RETURN_IF_NOT_OK(op->WalkAllFiles()); + for (size_t i = 0; i < op->image_names_.size(); ++i) { + std::ifstream image_reader; + image_reader.open(op->image_names_[i], std::ios::binary); + std::ifstream label_reader; + label_reader.open(op->label_names_[i], std::ios::binary); + + uint32_t num_images; + RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images)); + uint32_t num_labels; + RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels)); + CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), + "Invalid data, num of images is not equal to num of labels."); + + if (usage == "test10k") { + // only use the first 10k samples and drop the last 50k samples + num_images = 10000; + num_labels = 10000; + } else if (usage == "test50k") { + // only use the last 50k samples and drop the first 10k samples + num_images = 50000; + num_labels = 50000; + } + + *count = *count + num_images; + + // Close the readers + image_reader.close(); + label_reader.close(); + } + + return Status::OK(); +} + +Status QMnistOp::WalkAllFiles() { + const std::string image_ext = "images-idx3-ubyte"; + const std::string label_ext = "labels-idx2-int"; + const std::string train_prefix = "qmnist-train"; + const std::string test_prefix = "qmnist-test"; + const std::string nist_prefix = "xnist"; + + auto real_folder_path = FileUtils::GetRealPath(folder_path_.data()); + CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_); + Path root_dir(real_folder_path.value()); + + if (usage_ == "train") { + image_names_.push_back((root_dir / Path(train_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(train_prefix + "-" + label_ext)).ToString()); + } else if (usage_ == "test" || usage_ == "test10k" || usage_ == 
"test50k") { + image_names_.push_back((root_dir / Path(test_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(test_prefix + "-" + label_ext)).ToString()); + } else if (usage_ == "nist") { + image_names_.push_back((root_dir / Path(nist_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(nist_prefix + "-" + label_ext)).ToString()); + } else if (usage_ == "all") { + image_names_.push_back((root_dir / Path(train_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(train_prefix + "-" + label_ext)).ToString()); + image_names_.push_back((root_dir / Path(test_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(test_prefix + "-" + label_ext)).ToString()); + image_names_.push_back((root_dir / Path(nist_prefix + "-" + image_ext)).ToString()); + label_names_.push_back((root_dir / Path(nist_prefix + "-" + label_ext)).ToString()); + } + + CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(), + "Invalid data, num of images does not equal to num of labels."); + + for (size_t i = 0; i < image_names_.size(); i++) { + Path file_path(image_names_[i]); + CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), + "Failed to find " + DatasetName() + " image file: " + file_path.ToString()); + MS_LOG(INFO) << DatasetName(true) << " operator found image file at " << file_path.ToString() << "."; + } + + for (size_t i = 0; i < label_names_.size(); i++) { + Path file_path(label_names_[i]); + CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), + "Failed to find " + DatasetName() + " label file: " + file_path.ToString()); + MS_LOG(INFO) << DatasetName(true) << " operator found label file at " << file_path.ToString() << "."; + } + + return Status::OK(); +} + +Status QMnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) { + RETURN_UNEXPECTED_IF_NULL(image_reader); + 
RETURN_UNEXPECTED_IF_NULL(label_reader); + uint32_t num_images, num_labels; + RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images)); + RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels)); + CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), + "Invalid data, num_images is not equal to num_labels. Ensure data file is not damaged."); + + // The image size of the QMNIST dataset is fixed at [28,28] + int64_t image_size = kQMnistImageRows * kQMnistImageCols; + int64_t label_length = kQMnistLabelLength; + + if (usage_ == "test10k") { + // only use the first 10k samples and drop the last 50k samples + num_images = 10000; + num_labels = 10000; + } else if (usage_ == "test50k") { + num_images = 50000; + num_labels = 50000; + // skip the first 10k samples for ifstream reader + (void)image_reader->ignore(image_size * 10000); + (void)label_reader->ignore(label_length * 10000 * 4); + } + + auto images_buf = std::make_unique(image_size * num_images); + auto labels_buf = std::make_unique(label_length * num_labels); + if (images_buf == nullptr || labels_buf == nullptr) { + std::string err_msg = "[Internal ERROR] Failed to allocate memory for " + DatasetName() + " buffer."; + MS_LOG(ERROR) << err_msg.c_str(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)image_reader->read(images_buf.get(), image_size * num_images); + if (image_reader->fail()) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " image: " + image_names_[index] + + ", size:" + std::to_string(image_size * num_images) + + ". Ensure data file is not damaged."); + } + // uint32_t use 4 bytes in memory + (void)label_reader->read(reinterpret_cast(labels_buf.get()), label_length * num_labels * 4); + if (label_reader->fail()) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " label:" + label_names_[index] + + ", size: " + std::to_string(label_length * num_labels) + + ". 
Ensure data file is not damaged."); + } + TensorShape image_tensor_shape = TensorShape({kQMnistImageRows, kQMnistImageCols, 1}); + TensorShape label_tensor_shape = TensorShape({kQMnistLabelLength}); + for (int64_t data_index = 0; data_index != num_images; data_index++) { + auto image = &images_buf[data_index * image_size]; + for (int64_t image_index = 0; image_index < image_size; image_index++) { + image[image_index] = (image[image_index] == 0) ? 0 : 255; + } + std::shared_ptr image_tensor; + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(), + reinterpret_cast(image), &image_tensor)); + + auto label = &labels_buf[data_index * label_length]; + for (int64_t label_index = 0; label_index < label_length; label_index++) { + label[label_index] = SwapEndian(label[label_index]); + } + std::shared_ptr label_tensor; + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(label_tensor_shape, data_schema_->Column(1).Type(), + reinterpret_cast(label), &label_tensor)); + + image_info_pairs_.emplace_back(std::make_pair(image_tensor, label_tensor)); + image_label_pairs_.emplace_back(std::make_pair(image_tensor, label[0])); + image_path_.push_back(image_names_[index]); + label_path_.push_back(label_names_[index]); + } + return Status::OK(); +} + +Status QMnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) { + RETURN_UNEXPECTED_IF_NULL(label_reader); + RETURN_UNEXPECTED_IF_NULL(num_labels); + CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), + "Invalid file, failed to open " + DatasetName() + " label file: " + file_name); + int64_t label_len = label_reader->seekg(0, std::ios::end).tellg(); + (void)label_reader->seekg(0, std::ios::beg); + // The first 12 bytes of the label file are type, number and length + CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 12, "Invalid file, " + DatasetName() + " file is corrupted: " + file_name); + uint32_t magic_number; + RETURN_IF_NOT_OK(ReadFromReader(label_reader, 
&magic_number)); + CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kQMnistLabelFileMagicNumber, + "Invalid file, this is not the " + DatasetName() + " label file: " + file_name); + uint32_t num_items; + RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items)); + uint32_t length; + RETURN_IF_NOT_OK(ReadFromReader(label_reader, &length)); + CHECK_FAIL_RETURN_UNEXPECTED(length == kQMnistLabelLength, "Invalid data, length of labels is not equal to 8."); + + CHECK_FAIL_RETURN_UNEXPECTED((label_len - 12) == num_items * kQMnistLabelLength * 4, + "Invalid data, number of labels is wrong."); + *num_labels = num_items; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h new file mode 100644 index 00000000000..2e0b8f646dc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h @@ -0,0 +1,113 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { + +using QMnistImageInfoPair = std::pair, std::shared_ptr>; + +class QMnistOp : public MnistOp { + public: + // Constructor. + // @param const std::string &folder_path - dir directory of QMNIST data file. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or + // 'all'. + // @param bool compat - Compatibility with Mnist. + // @param std::unique_ptr data_schema - the schema of the QMNIST dataset. + // @param td::unique_ptr sampler - sampler tells QMnistOp what to read. + // @param int32_t num_workers - number of workers reading images in parallel. + // @param int32_t queue_size - connector queue size. + QMnistOp(const std::string &folder_path, const std::string &usage, bool compat, + std::unique_ptr data_schema, std::shared_ptr sampler, int32_t num_workers, + int32_t queue_size); + + // Destructor. + ~QMnistOp() = default; + + // Op name getter. + // @return std::string - Name of the current Op. + std::string Name() const override { return "QMnistOp"; } + + // DatasetName name getter + // \return std::string - DatasetName of the current Op + std::string DatasetName(bool upper = false) const { return upper ? 
"QMnist" : "qmnist"; } + + // A print method typically used for debugging. + // @param std::ostream &out - out stream. + // @param bool show_all - whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + // Function to count the number of samples in the QMNIST dataset. + // @param const std::string &dir - path to the QMNIST directory. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or + // 'all'. + // @param int64_t *count - output arg that will hold the actual dataset size. + // @return Status -The status code returned. + static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); + + private: + // Load a tensor row according to a pair. + // @param row_id_type row_id - id for this tensor row. + // @param TensorRow row - image & label read into this tensor row. + // @return Status - The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; + + // Get needed files in the folder_path_. + // @return Status - The status code returned. + Status WalkAllFiles() override; + + // Read images and labels from the file stream. + // @param std::ifstream *image_reader - image file stream. + // @param std::ifstream *label_reader - label file stream. + // @param size_t index - the index of file that is reading. + // @return Status The status code returned. + Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) override; + + // Check label stream. + // @param const std::string &file_name - label file name. + // @param std::ifstream *label_reader - label file stream. + // @param uint32_t num_labels - returns the number of labels. + // @return Status The status code returned. 
+ Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) override; + + const bool compat_; // compatible with mnist + + std::vector image_info_pairs_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 60bece025d8..4c328114aa2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -89,6 +89,7 @@ constexpr char kImageFolderNode[] = "ImageFolderDataset"; constexpr char kManifestNode[] = "ManifestDataset"; constexpr char kMindDataNode[] = "MindDataDataset"; constexpr char kMnistNode[] = "MnistDataset"; +constexpr char kQMnistNode[] = "QMnistDataset"; constexpr char kRandomNode[] = "RandomDataset"; constexpr char kSBUNode[] = "SBUDataset"; constexpr char kTextFileNode[] = "TextFileDataset"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt index 6f5b3b2209b..6144e87b540 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt @@ -17,6 +17,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES manifest_node.cc minddata_node.cc mnist_node.cc + qmnist_node.cc random_node.cc sbu_node.cc text_file_node.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc new file mode 100644 index 00000000000..9879783c5c3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2021 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/qmnist_op.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/serdes.h" +#endif +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, + std::shared_ptr sampler, std::shared_ptr cache) + : MappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + usage_(usage), + compat_(compat), + sampler_(sampler) {} + +std::shared_ptr QMnistNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, usage_, compat_, sampler, cache_); + return node; +} + +void QMnistNode::Print(std::ostream &out) const { + out << (Name() + "(dataset dir: " + dataset_dir_ + ", usage: " + usage_ + + ", compat: " + (compat_ ? "true" : "false") + ", cache: " + ((cache_ != nullptr) ? 
"true" : "false") + ")"); +} + +Status QMnistNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistNode", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistNode", sampler_)); + RETURN_IF_NOT_OK(ValidateStringValue("QMnistNode", usage_, {"train", "test", "test10k", "test50k", "nist", "all"})); + return Status::OK(); +} + +Status QMnistNode::Build(std::vector> *const node_ops) { + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + if (compat_) { + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + } else { + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + } + + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + + auto op = std::make_shared(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt), + num_workers_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + + return Status::OK(); +} + +// Get the shard id of node +Status QMnistNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + +// Get Dataset size +Status QMnistNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(QMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = 
sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status QMnistNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["compat"] = compat_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), + "Failed to find num_parallel_workers"); + CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); + CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); + CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("compat") != json_obj.end(), "Failed to find compat"); + CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); + std::string dataset_dir = json_obj["dataset_dir"]; + std::string usage = json_obj["usage"]; + bool compat = json_obj["compat"]; + std::shared_ptr sampler; + RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); + std::shared_ptr cache = nullptr; + RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); + *ds = std::make_shared(dataset_dir, usage, compat, sampler, cache); + (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); + return Status::OK(); +} +#endif +} // namespace dataset +} // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h new file mode 100644 index 00000000000..f70fe969dc6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h @@ -0,0 +1,111 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class QMnistNode : public MappableSourceNode { + public: + /// \brief Constructor. + QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor. + ~QMnistNode() = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kQMnistNode; } + + /// \brief Print the description. + /// \param out - The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. 
+ std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + + /// \brief Getter functions. + const std::string &Usage() const { return usage_; } + + /// \brief Getter functions. + const bool Compat() const { return compat_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + +#ifndef ENABLE_ANDROID + /// \brief Function to read dataset in json + /// \param[in] json_obj The JSON object to be deserialized + /// \param[out] ds Deserialized dataset + /// \return Status The status code returned + static Status from_json(nlohmann::json json_obj, std::shared_ptr *ds); +#endif + + /// \brief Sampler getter. 
+ /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + std::string usage_; + bool compat_; + std::shared_ptr sampler_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h index b8fe7cab656..d632c626142 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h @@ -2273,6 +2273,90 @@ inline std::shared_ptr Mnist(const std::string &dataset_dir, const return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); } +/// \class QMnistDataset +/// \brief A source dataset that reads and parses QMNIST dataset. +class QMnistDataset : public Dataset { + public: + /// \brief Constructor of QMnistDataset. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all". + /// \param[in] compat Whether the label for each example is class number (compat=true) + /// or the full QMNIST information (compat=false). + /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not + /// given, a `RandomSampler` will be used to randomly iterate the entire dataset. + /// \param[in] cache Tensor cache to use. + explicit QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const std::shared_ptr &sampler, const std::shared_ptr &cache); + + /// \brief Constructor of QMnistDataset. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. 
+ /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all". + /// \param[in] compat Whether the label for each example is class number (compat=true) + /// or the full QMNIST information (compat=false). + /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. + /// \param[in] cache Tensor cache to use. + explicit QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const Sampler *sampler, const std::shared_ptr &cache); + + /// \brief Constructor of QMnistDataset. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all". + /// \param[in] compat Whether the label for each example is class number (compat=true) + /// or the full QMNIST information (compat=false). + /// \param[in] sampler Sampler object used to choose samples from the dataset. + /// \param[in] cache Tensor cache to use. + explicit QMnistDataset(const std::vector &dataset_dir, const std::vector &usage, bool compat, + const std::reference_wrapper sampler, const std::shared_ptr &cache); + + /// Destructor of QMnistDataset. + ~QMnistDataset() = default; +}; + +/// \brief Function to create a QMnistDataset. +/// \note The generated dataset has two columns ["image", "label"]. +/// \param[in] dataset_dir Path to the root directory that contains the dataset. +/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all" (default = "all"). +/// \param[in] compat Whether the label for each example is class number or the full QMNIST information +/// (default = true). +/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not +/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()). 
+/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). +/// \return Shared pointer to the QMnistDataset. +inline std::shared_ptr QMnist( + const std::string &dataset_dir, const std::string &usage = "all", bool compat = true, + const std::shared_ptr &sampler = std::make_shared(), + const std::shared_ptr &cache = nullptr) { + return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache); +} + +/// \brief Function to create a QMnistDataset. +/// \note The generated dataset has two columns ["image", "label"]. +/// \param[in] dataset_dir Path to the root directory that contains the dataset. +/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all". +/// \param[in] compat Whether the label for each example is class number or the full QMNIST information. +/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. +/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). +/// \return Shared pointer to the QMnistDataset. +inline std::shared_ptr QMnist(const std::string &dataset_dir, const std::string &usage, bool compat, + const Sampler *sampler, + const std::shared_ptr &cache = nullptr) { + return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache); +} + +/// \brief Function to create a QMnistDataset. +/// \note The generated dataset has two columns ["image", "label"]. +/// \param[in] dataset_dir Path to the root directory that contains the dataset. +/// \param[in] usage Usage of QMNIST, can be "train", "test", "test10k", "test50k", "nist" or "all". +/// \param[in] compat Whether the label for each example is class number or the full QMNIST information. +/// \param[in] sampler Sampler object used to choose samples from the dataset. +/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). 
+/// \return Shared pointer to the QMnistDataset. +inline std::shared_ptr QMnist(const std::string &dataset_dir, const std::string &usage, bool compat, + const std::reference_wrapper sampler, + const std::shared_ptr &cache = nullptr) { + return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), compat, sampler, cache); +} + /// \brief Function to create a ConcatDataset. /// \note Reload "+" operator to concat two datasets. /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated. @@ -2565,15 +2649,14 @@ class USPSDataset : public Dataset { public: /// \brief Constructor of USPSDataset. /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all"). - /// \param[in] num_samples The number of samples to be included in the dataset - /// (Default = 0 means all samples). - /// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal). + /// \param[in] usage Usage of USPS, can be "train", "test" or "all". + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] shuffle The mode for shuffling data every epoch. /// Can be any of: /// ShuffleMode.kFalse - No shuffling is performed. /// ShuffleMode.kFiles - Shuffle files only. /// ShuffleMode.kGlobal - Shuffle both the files and samples. - /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1). + /// \param[in] num_shards Number of shards that the dataset should be divided into. /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified (Default = 0). /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). 
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h b/mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h index 59898d4d4f2..902307a8964 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h @@ -44,6 +44,7 @@ class Sampler : std::enable_shared_from_this { friend class ManifestDataset; friend class MindDataDataset; friend class MnistDataset; + friend class QMnistDataset; friend class RandomDataDataset; friend class SBUDataset; friend class TextFileDataset; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index bdaf7c3cc0a..038b6d0bdc7 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -66,7 +66,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \ check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \ check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \ - check_sbu_dataset + check_sbu_dataset, check_qmnist_dataset from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ get_prefetch_size from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist @@ -3436,6 +3436,131 @@ class MnistDataset(MappableDataset): return cde.MnistNode(self.dataset_dir, self.usage, self.sampler) +class QMnistDataset(MappableDataset): + """ + A source dataset for reading and parsing the QMNIST dataset. + + The generated dataset has two columns :py:obj:`[image, label]`. + The tensor of column :py:obj:`image` is of the uint8 type. + The tensor of column :py:obj:`label` is a scalar when `compat` is True else a tensor both of the uint32 type. 
+ + Args: + dataset_dir (str): Path to the root directory that contains the dataset. + usage (str, optional): Usage of this dataset, can be `train`, `test`, `test10k`, `test50k`, `nist` + or `all` (default=None, will read all samples). + compat (bool, optional): Whether the label for each example is class number (compat=True) or the full QMNIST + information (compat=False) (default=True). + num_samples (int, optional): The number of images to be included in the dataset + (default=None, will read all images). + num_parallel_workers (int, optional): Number of workers to read the data + (default=None, will use value set in the config). + shuffle (bool, optional): Whether or not to perform shuffle on the dataset + (default=None, expected order behavior shown in the table). + sampler (Sampler, optional): Object used to choose samples from the + dataset (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). + When this argument is specified, `num_samples` reflects the maximum sample number of per shard. + shard_id (int, optional): The shard ID within `num_shards` (default=None). This + argument can only be specified when `num_shards` is also specified. + cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. + (default=None, which means no cache is used). + + Raises: + RuntimeError: If dataset_dir does not contain data files. + RuntimeError: If num_parallel_workers exceeds the max thread numbers. + RuntimeError: If sampler and shuffle are specified at the same time. + RuntimeError: If sampler and sharding are specified at the same time. + RuntimeError: If num_shards is specified but shard_id is None. + RuntimeError: If shard_id is specified but num_shards is None. + ValueError: If shard_id is invalid (< 0 or >= num_shards). + + Note: + - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 
+ The table below shows what input arguments are allowed and their expected behavior. + + .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` + :widths: 25 25 50 + :header-rows: 1 + + * - Parameter `sampler` + - Parameter `shuffle` + - Expected Order Behavior + * - None + - None + - random order + * - None + - True + - random order + * - None + - False + - sequential order + * - Sampler object + - None + - order defined by sampler + * - Sampler object + - True + - not allowed + * - Sampler object + - False + - not allowed + + Examples: + >>> qmnist_dataset_dir = "/path/to/qmnist_dataset_directory" + >>> + >>> # Read 3 samples from QMNIST train dataset + >>> dataset = ds.QMnistDataset(dataset_dir=qmnist_dataset_dir, num_samples=3) + >>> + >>> # Note: In QMNIST dataset, each dictionary has keys "image" and "label" + + About QMNIST dataset: + + The QMNIST dataset was generated from the original data found in the NIST Special Database 19 with the goal to + match the MNIST preprocessing as closely as possible. + Through an iterative process, researchers tried to generate an additional 50k images of MNIST-like data. + They started with a reconstruction process given in the paper and used the Hungarian algorithm to find the best + matches between the original MNIST samples and their reconstructed samples. + + Here is the original QMNIST dataset structure. + You can unzip the dataset files into this directory structure and read by MindSpore's API. + + .. code-block:: + + . + └── qmnist_dataset_dir + ├── qmnist-train-images-idx3-ubyte + ├── qmnist-train-labels-idx2-int + ├── qmnist-test-images-idx3-ubyte + ├── qmnist-test-labels-idx2-int + ├── xnist-images-idx3-ubyte + └── xnist-labels-idx2-int + + Citation: + + .. 
code-block:: + + @incollection{qmnist-2019, + title = "Cold Case: The Lost MNIST Digits", + author = "Chhavi Yadav and L\'{e}on Bottou",\ + booktitle = {Advances in Neural Information Processing Systems 32}, + year = {2019}, + publisher = {Curran Associates, Inc.}, + } + """ + + @check_qmnist_dataset + def __init__(self, dataset_dir, usage=None, compat=True, num_samples=None, num_parallel_workers=None, + shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) + + self.dataset_dir = dataset_dir + self.usage = replace_none(usage, "all") + self.compat = compat + + def parse(self, children=None): + return cde.QMnistNode(self.dataset_dir, self.usage, self.compat, self.sampler) + + class MindDataset(MappableDataset): """ A source dataset for reading and parsing MindRecord dataset. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ec0c2b2d7aa..208c5135e39 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -92,6 +92,36 @@ def check_mnist_cifar_dataset(method): return new_method +def check_qmnist_dataset(method): + """A wrapper that wraps a parameter checker around the original Dataset(QMnistDataset).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + nreq_param_bool = ['shuffle', 'compat'] + + dataset_dir = param_dict.get('dataset_dir') + check_dir(dataset_dir) + + usage = param_dict.get('usage') + if usage is not None: + check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage") + + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, 
param_dict, bool) + + check_sampler_shuffle_shard_options(param_dict) + + cache = param_dict.get('cache') + check_cache_option(cache) + + return method(self, *args, **kwargs) + + return new_method + + def check_manifestdataset(method): """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset).""" diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index be6bf54ab35..7d93ecc2c86 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -28,6 +28,7 @@ SET(DE_UT_SRCS c_api_dataset_manifest_test.cc c_api_dataset_minddata_test.cc c_api_dataset_ops_test.cc + c_api_dataset_qmnist_test.cc c_api_dataset_randomdata_test.cc c_api_dataset_save.cc c_api_dataset_sbu_test.cc diff --git a/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc b/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc new file mode 100644 index 00000000000..6e2d79f3171 --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc @@ -0,0 +1,343 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "common/common.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; +using mindspore::dataset::DataType; +using mindspore::dataset::Tensor; +using mindspore::dataset::TensorShape; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +TEST_F(MindDataTestPipeline, TestQMnistTrainDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDataset."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestQMnistTestDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTestDataset."; + + // Create a QMNIST Test Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "test", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestQMnistNistDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistNistDataset."; + + // Create a QMNIST Nist Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "nist", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestQMnistAllDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistAllDataset."; + + // Create a QMNIST All Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "all", true, std::make_shared(false, 20)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestQMnistCompatDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistCompatDataset."; + + // Create a QMNIST All Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "all", false, std::make_shared(false, 20)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + MS_LOG(INFO) << "Tensor label shape: " << label.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestQMnistDatasetWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetWithPipeline."; + + // Create two QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds1 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + std::shared_ptr ds2 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds + int32_t repeat_num = 1; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 1; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create two Project operation on ds + std::vector column_project = {"image", "label"}; + ds1 = ds1->Project(column_project); + EXPECT_NE(ds1, nullptr); + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestGetQMnistDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetQMnistTrainDatasetSize."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + +TEST_F(MindDataTestPipeline, TestQMnistDatasetGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetGetters."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + std::vector column_names = {"image", "label"}; + int64_t num_classes = ds->GetNumClasses(); + EXPECT_EQ(types.size(), 2); + EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "uint32"); + EXPECT_EQ(shapes.size(), 2); + EXPECT_EQ(shapes[0].ToString(), "<28,28,1>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + EXPECT_EQ(num_classes, -1); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + + EXPECT_EQ(ds->GetDatasetSize(), 10); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + 
EXPECT_EQ(ds->GetNumClasses(), -1); + + EXPECT_EQ(ds->GetColumnNames(), column_names); + EXPECT_EQ(ds->GetDatasetSize(), 10); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + EXPECT_EQ(ds->GetNumClasses(), -1); + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + +TEST_F(MindDataTestPipeline, TestQMnistDataFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataFail."; + + // Create a QMNIST Dataset + std::shared_ptr ds = QMnist("", "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input + EXPECT_EQ(iter, nullptr); +} + +TEST_F(MindDataTestPipeline, TestQMnistDataWithInvalidUsageFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithInvalidUsageFail."; + + // Create a QMNIST Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "validation", true); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input, validation is not a valid usage + EXPECT_EQ(iter, nullptr); +} + +TEST_F(MindDataTestPipeline, TestQMnistDataWithNullSamplerFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithNullSamplerFail."; + + // Create a QMNIST Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input, sampler cannot be nullptr + EXPECT_EQ(iter, nullptr); +} diff --git 
a/tests/ut/data/dataset/testQMnistData/qmnist-test-images-idx3-ubyte b/tests/ut/data/dataset/testQMnistData/qmnist-test-images-idx3-ubyte new file mode 100644 index 00000000000..738d84b010d Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/qmnist-test-images-idx3-ubyte differ diff --git a/tests/ut/data/dataset/testQMnistData/qmnist-test-labels-idx2-int b/tests/ut/data/dataset/testQMnistData/qmnist-test-labels-idx2-int new file mode 100644 index 00000000000..4a8c575d37f Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/qmnist-test-labels-idx2-int differ diff --git a/tests/ut/data/dataset/testQMnistData/qmnist-train-images-idx3-ubyte b/tests/ut/data/dataset/testQMnistData/qmnist-train-images-idx3-ubyte new file mode 100644 index 00000000000..738d84b010d Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/qmnist-train-images-idx3-ubyte differ diff --git a/tests/ut/data/dataset/testQMnistData/qmnist-train-labels-idx2-int b/tests/ut/data/dataset/testQMnistData/qmnist-train-labels-idx2-int new file mode 100644 index 00000000000..4a8c575d37f Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/qmnist-train-labels-idx2-int differ diff --git a/tests/ut/data/dataset/testQMnistData/xnist-images-idx3-ubyte b/tests/ut/data/dataset/testQMnistData/xnist-images-idx3-ubyte new file mode 100644 index 00000000000..738d84b010d Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/xnist-images-idx3-ubyte differ diff --git a/tests/ut/data/dataset/testQMnistData/xnist-labels-idx2-int b/tests/ut/data/dataset/testQMnistData/xnist-labels-idx2-int new file mode 100644 index 00000000000..4a8c575d37f Binary files /dev/null and b/tests/ut/data/dataset/testQMnistData/xnist-labels-idx2-int differ diff --git a/tests/ut/python/dataset/test_datasets_qmnist.py b/tests/ut/python/dataset/test_datasets_qmnist.py new file mode 100644 index 00000000000..4c65228dffc --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_qmnist.py @@ 
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test QMnistDataset operator
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testQMnistData"


def load_qmnist(path, usage, compat=True):
    """
    Load QMNIST images and labels directly from the raw idx files, serving as
    the reference implementation that test_qmnist_content_check compares the
    QMnistDataset operator against.

    Args:
        path (str): directory containing the QMNIST idx files.
        usage (str): one of "train", "test", "nist" or "all".
        compat (bool): if True, return only the class label (first of the 8
            label fields); otherwise return the full 8-column label record.

    Returns:
        tuple: (images, labels) numpy arrays; images have shape (N, 28, 28, 1).
    """
    image_ext = "images-idx3-ubyte"
    label_ext = "labels-idx2-int"
    train_prefix = "qmnist-train"
    test_prefix = "qmnist-test"
    nist_prefix = "xnist"
    assert usage in ["train", "test", "nist", "all"]
    # Map each usage flag to the file prefixes it covers; "all" reads every split.
    prefixes = {
        "train": [train_prefix],
        "test": [test_prefix],
        "nist": [nist_prefix],
        "all": [train_prefix, test_prefix, nist_prefix],
    }[usage]
    image_path = [os.path.realpath(os.path.join(path, prefix + "-" + image_ext)) for prefix in prefixes]
    label_path = [os.path.realpath(os.path.join(path, prefix + "-" + label_ext)) for prefix in prefixes]

    assert len(image_path) == len(label_path)

    images = []
    labels = []
    for one_image_path, one_label_path in zip(image_path, label_path):
        with open(one_image_path, 'rb') as image_file:
            image_file.read(16)  # skip the 16-byte idx3 image header
            image = np.fromfile(image_file, dtype=np.uint8)
            image = image.reshape(-1, 28, 28, 1)
            image[image > 0] = 255  # Perform binarization to maintain consistency with our API
            images.append(image)
        with open(one_label_path, 'rb') as label_file:
            label_file.read(12)  # skip the 12-byte idx2 label header
            # QMNIST labels are big-endian uint32 with 8 fields per sample.
            label = np.fromfile(label_file, dtype='>u4')
            label = label.reshape(-1, 8)
            labels.append(label)

    images = np.concatenate(images, 0)
    labels = np.concatenate(labels, 0)
    if compat:
        return images, labels[:, 0]
    return images, labels


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
        plt.title(labels[i])
    plt.show()


def test_qmnist_content_check():
    """
    Validate QMnistDataset image readings against the raw idx files
    """
    logger.info("Test QMnistDataset Op with content check")

    def check_content(usage, compat):
        # Compare the first 10 rows produced by the operator with the
        # reference data loaded directly from the idx files.
        data1 = ds.QMnistDataset(DATA_DIR, usage, compat, num_samples=10, shuffle=False)
        images, labels = load_qmnist(DATA_DIR, usage, compat)
        num_iter = 0
        # in this example, each dictionary has keys "image" and "label"
        for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
            np.testing.assert_array_equal(data["image"], images[i])
            np.testing.assert_array_equal(data["label"], labels[i])
            num_iter += 1
        assert num_iter == 10

    # Verify both the compat (scalar label) and full-record label layouts.
    for usage in ["train", "test", "nist", "all"]:
        check_content(usage, True)
    for usage in ["train", "test", "nist", "all"]:
        check_content(usage, False)


def test_qmnist_basic():
    """
    Validate QMnistDataset
    """
    logger.info("Test QMnistDataset Op")

    # case 1: test loading whole dataset
    data1 = ds.QMnistDataset(DATA_DIR, "train", True)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 10

    # case 2: test num_samples
    data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 5

    # case 3: test repeat
    data3 = ds.QMnistDataset(DATA_DIR, "train", True)
    data3 = data3.repeat(5)
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 50

    # case 4: test batch with drop_remainder=False
    data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert data4.get_dataset_size() == 10
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder is default to be False
    assert data4.get_dataset_size() == 2
    assert data4.get_batch_size() == 7
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 2

    # case 5: test batch with drop_remainder=True
    data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert data5.get_dataset_size() == 10
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=3, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data5.get_dataset_size() == 3
    assert data5.get_batch_size() == 3
    num_iter5 = 0
    for _ in data5.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 3

    # case 6: test get_col_names
    dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
    assert dataset.get_col_names() == ["image", "label"]


def test_qmnist_pk_sampler():
    """
    Test QMnistDataset with PKSampler
    """
    logger.info("Test QMnistDataset Op with PKSampler")
    golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    sampler = ds.PKSampler(10)
    data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 10


def test_qmnist_sequential_sampler():
    """
    Test QMnistDataset with SequentialSampler: an explicit SequentialSampler
    must yield the same rows as shuffle=False with the same num_samples.
    """
    logger.info("Test QMnistDataset Op with SequentialSampler")
    num_samples = 10
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.QMnistDataset(DATA_DIR, "train", True, sampler=sampler)
    data2 = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=num_samples)
    label_list1, label_list2 = [], []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
        label_list1.append(item1["label"].asnumpy())
        label_list2.append(item2["label"].asnumpy())
        num_iter += 1
    np.testing.assert_array_equal(label_list1, label_list2)
    assert num_iter == num_samples


def test_qmnist_exception():
    """
    Test error cases for QMnistDataset
    """
    # NOTE: fixed copy-paste error — the log previously said "MnistDataset".
    logger.info("Test error cases for QMnistDataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # A user map function that raises must surface as a RuntimeError
    # pointing at the data files, on both image and label columns.
    error_msg_8 = "The corresponding data files"
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data:
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data:
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.QMnistDataset(DATA_DIR, "train", True)
        data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
        for _ in data:
            pass


def test_qmnist_visualize(plot=False):
    """
    Visualize QMnistDataset results
    """
    logger.info("Test QMnistDataset visualization")

    data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


def test_qmnist_usage():
    """
    Validate QMnistDataset usage flag: each valid split yields the expected
    row count and invalid values produce descriptive error messages.
    """
    logger.info("Test QMnistDataset usage flag")

    def test_config(usage, path=None):
        # Return the row count on success, or the error text on failure so
        # the caller can assert on either outcome.
        path = DATA_DIR if path is None else path
        try:
            data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    assert test_config("train") == 10
    assert test_config("test") == 10
    assert test_config("nist") == 10
    assert test_config("all") == 30
    assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in\
        test_config("invalid")
    assert "Argument usage with value ['list'] is not of type []" in test_config(["list"])


if __name__ == '__main__':
    test_qmnist_content_check()
    test_qmnist_basic()
    test_qmnist_pk_sampler()
    test_qmnist_sequential_sampler()
    test_qmnist_exception()
    test_qmnist_visualize(plot=True)
    test_qmnist_usage()