[feat][assistant][I3J6V2] add new data operator EMnist

This commit is contained in:
G-Dragon-Liu 2021-08-19 10:23:15 +00:00
parent 2c1d3baace
commit 32baa520bf
22 changed files with 1611 additions and 3 deletions

View File

@ -96,6 +96,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
@ -1042,6 +1043,33 @@ DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vect
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
bool decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {

View File

@ -33,6 +33,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
@ -152,6 +153,17 @@ PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
(void)py::class_<EMnistNode, DatasetNode, std::shared_ptr<EMnistNode>>(*m, "EMnistNode",
"to create an EMnistNode")
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) {
auto emnist =
std::make_shared<EMnistNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(emnist->ValidateParams());
return emnist;
}));
}));
PYBIND_REGISTER(
FlickrNode, 2, ([](const py::module *m) {
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")

View File

@ -22,6 +22,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
div2k_op.cc
flickr_op.cc
qmnist_op.cc
emnist_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

View File

@ -0,0 +1,146 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/emnist_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
EMnistOp::EMnistOp(const std::string &name, const std::string &usage, int32_t num_workers,
const std::string &folder_path, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: MnistOp(usage, num_workers, folder_path, queue_size, std::move(data_schema), std::move(sampler)), name_(name) {}
void EMnistOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nNumber of rows:" << num_rows_ << "\n"
<< DatasetName(true) << " directory: " << folder_path_ << "\nName: " << name_ << "\nUsage: " << usage_
<< "\n\n";
}
}
Status EMnistOp::WalkAllFiles() {
const std::string img_ext = "-images-idx3-ubyte";
const std::string lbl_ext = "-labels-idx1-ubyte";
const std::string train_prefix = "-train";
const std::string test_prefix = "-test";
auto realpath = FileUtils::GetRealPath(folder_path_.data());
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed: " + folder_path_);
Path dir(realpath.value());
auto dir_it = Path::DirIterator::OpenDirectory(&dir);
if (dir_it == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid path, failed to open directory: " + dir.ToString());
}
std::string prefix;
prefix = "emnist-" + name_; // used to match usage == "all".
if (usage_ == "train" || usage_ == "test") {
prefix += (usage_ == "test" ? test_prefix : train_prefix);
}
if (dir_it != nullptr) {
while (dir_it->HasNext()) {
Path file = dir_it->Next();
std::string fname = file.Basename(); // name of the emnist file.
if ((fname.find(prefix) != std::string::npos) && (fname.find(img_ext) != std::string::npos)) {
image_names_.push_back(file.ToString());
MS_LOG(INFO) << DatasetName(true) << " operator found image file at " << fname << ".";
} else if ((fname.find(prefix) != std::string::npos) && (fname.find(lbl_ext) != std::string::npos)) {
label_names_.push_back(file.ToString());
MS_LOG(INFO) << DatasetName(true) << " operator found label file at " << fname << ".";
}
}
} else {
MS_LOG(WARNING) << DatasetName(true) << " operator unable to open directory " << dir.ToString() << ".";
}
std::sort(image_names_.begin(), image_names_.end());
std::sort(label_names_.begin(), label_names_.end());
CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(),
"Invalid data, num of images does not equal to num of labels.");
return Status::OK();
}
Status EMnistOp::CountTotalRows(const std::string &dir, const std::string &name, const std::string &usage,
int64_t *count) {
// the logic of counting the number of samples is copied from ParseEMnistData() and uses CheckReader().
RETURN_UNEXPECTED_IF_NULL(count);
*count = 0;
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connect_size = cfg->op_connector_size();
auto op =
std::make_shared<EMnistOp>(name, usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
RETURN_IF_NOT_OK(op->WalkAllFiles());
for (size_t i = 0; i < op->image_names_.size(); ++i) {
std::ifstream image_reader;
image_reader.open(op->image_names_[i], std::ios::binary);
CHECK_FAIL_RETURN_UNEXPECTED(image_reader.is_open(),
"Invalid file, failed to open image file: " + op->image_names_[i]);
std::ifstream label_reader;
label_reader.open(op->label_names_[i], std::ios::binary);
CHECK_FAIL_RETURN_UNEXPECTED(label_reader.is_open(),
"Invalid file, failed to open label file: " + op->label_names_[i]);
uint32_t num_images;
Status s = op->CheckImage(op->image_names_[i], &image_reader, &num_images);
image_reader.close();
RETURN_IF_NOT_OK(s);
uint32_t num_labels;
s = op->CheckLabel(op->label_names_[i], &label_reader, &num_labels);
label_reader.close();
RETURN_IF_NOT_OK(s);
CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
"Invalid data, num of images is not equal to num of labels.");
*count = *count + num_images;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EMNIST_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EMNIST_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;
class EMnistOp : public MnistOp {
public:
// Constructor.
// @param const std::string &name - Class of this dataset, can be
// "byclass","bymerge","balanced","letters","digits","mnist".
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'.
// @param int32_t num_workers - Number of workers reading images in parallel.
// @param const std::string &folder_path - Dir directory of emnist.
// @param int32_t queue_size - Connector queue size.
// @param std::unique_ptr<DataSchema> data_schema - The schema of the Emnist dataset.
// @param std::shared_ptr<SamplerRT> sampler - Sampler tells EMnistOp what to read.
EMnistOp(const std::string &name, const std::string &usage, int32_t num_workers, const std::string &folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
// Destructor.
~EMnistOp() = default;
// A print method typically used for debugging.
// @param std::ostream &out - Out stream.
// @param bool show_all - Whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the EMNIST dataset.
// @param const std::string &dir - Path to the EMNIST directory.
// @param const std::string &name - Class of this dataset, can be
// "byclass","bymerge","balanced","letters","digits","mnist".
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'.
// @param int64_t *count - Output arg that will hold the minimum of the actual dataset size and numSamples.
// @return Status The status code returned.
static Status CountTotalRows(const std::string &dir, const std::string &name, const std::string &usage,
int64_t *count);
// Op name getter.
// @return Name of the current Op.
std::string Name() const override { return "EMnistOp"; }
// DatasetName name getter.
// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const override { return upper ? "EMnist" : "emnist"; }
private:
// Read all files in the directory.
// @return Status The status code returned.
Status WalkAllFiles() override;
const std::string name_; // can be "byclass", "bymerge", "balanced", "letters", "digits", "mnist".
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EMNIST_OP_H_

View File

@ -83,6 +83,7 @@ constexpr char kCLUENode[] = "CLUEDataset";
constexpr char kCocoNode[] = "CocoDataset";
constexpr char kCSVNode[] = "CSVDataset";
constexpr char kDIV2KNode[] = "DIV2KDataset";
constexpr char kEMnistNode[] = "EMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";

View File

@ -12,6 +12,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
coco_node.cc
csv_node.cc
div2k_node.cc
emnist_node.cc
flickr_node.cc
image_folder_node.cc
manifest_node.cc
@ -33,4 +34,4 @@ if(ENABLE_PYTHON)
)
endif()
add_library(engine-ir-datasetops-source OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES})
add_library(engine-ir-datasetops-source OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES})

View File

@ -0,0 +1,121 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/emnist_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
EMnistNode::EMnistNode(const std::string &dataset_dir, const std::string &name, const std::string &usage,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), name_(name), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> EMnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<EMnistNode>(dataset_dir_, name_, usage_, sampler, cache_);
return node;
}
void EMnistNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
}
Status EMnistNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("EMnistNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("EMnistNode", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue("EMnistNode", usage_, {"train", "test", "all"}));
RETURN_IF_NOT_OK(
ValidateStringValue("EMnistNode", name_, {"byclass", "bymerge", "balanced", "letters", "digits", "mnist"}));
return Status::OK();
}
Status EMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<EMnistOp>(name_, usage_, num_workers_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_rt));
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}
// Get the shard id of node.
Status EMnistNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size.
Status EMnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(EMnistOp::CountTotalRows(dataset_dir_, name_, usage_, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status EMnistNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["name"] = name_;
args["usage"] = usage_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EMNIST_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EMNIST_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class EMnistNode : public MappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] dataset_dir Dataset directory of emnist.
/// \param[in] name Class of this dataset, can be "byclass", "bymerge", "balanced", "letters", "digits", "mnist".
/// \param[in] usage Usage of this dataset, can be 'train', 'test' or 'all'.
/// \param[in] sampler Tells EMnistOp what to read.
/// \param[in] cache Tensor cache to use.
EMnistNode(const std::string &dataset_dir, const std::string &name, const std::string &usage,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~EMnistNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return "EMnistNode"; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[in] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
/// \return Dataset direction.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Getter functions.
/// \return Usage.
const std::string &Usage() const { return usage_; }
/// \brief Getter functions.
/// \return Name.
const std::string &GetName() const { return name_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Tells EMnistOp what to read.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
std::string name_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EMNIST_NODE_H_

View File

@ -1628,6 +1628,93 @@ inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const
decode, sampler, cache);
}
/// \class EMnistDataset
/// \brief A source dataset for reading and parsing EMnist dataset.
class EMnistDataset : public Dataset {
public:
/// \brief Constructor of EMnistDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits"
/// or "mnist".
/// \param[in] usage Part of dataset of EMNIST, can be "train", "test" or "all".
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
/// \param[in] cache Tensor cache to use.
explicit EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of EMnistDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits"
/// or "mnist".
/// \param[in] usage Part of dataset of EMNIST, can be "train", "test" or "all".
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of EMnistDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits"
/// or "mnist".
/// \param[in] usage Part of dataset of EMNIST, can be "train", "test" or "all".
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
explicit EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache);
~EMnistDataset() = default;
};
/// \brief Function to create a EMnistDataset.
/// \notes The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits" or "mnist".
/// \param[in] usage Usage of EMNIST, can be "train", "test" or "all" (default = "all").
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not.
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current EMnistDataset.
inline std::shared_ptr<EMnistDataset> EMnist(
const std::string &dataset_dir, const std::string &name, const std::string &usage = "all",
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
cache);
}
/// \brief Function to create a EMnistDataset.
/// \notes The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits" or "mnist".
/// \param[in] usage Usage of EMNIST, can be "train", "test" or "all".
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current EMnistDataset.
inline std::shared_ptr<EMnistDataset> EMnist(const std::string &dataset_dir, const std::string &usage,
const std::string &name, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
cache);
}
/// \brief Function to create a EMnistDataset.
/// \notes The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] name Name of splits for EMNIST, can be "byclass", "bymerge", "balanced", "letters", "digits" or "mnist".
/// \param[in] usage Usage of EMNIST, can be "train", "test" or "all".
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current EMnistDataset.
inline std::shared_ptr<EMnistDataset> EMnist(const std::string &dataset_dir, const std::string &name,
const std::string &usage, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
cache);
}
/// \class FlickrDataset
/// \brief A source dataset for reading and parsing Flickr dataset.
class FlickrDataset : public Dataset {

View File

@ -39,6 +39,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class CocoDataset;
friend class CSVDataset;
friend class DIV2KDataset;
friend class EMnistDataset;
friend class FlickrDataset;
friend class ImageFolderDataset;
friend class ManifestDataset;

View File

@ -66,7 +66,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -6350,6 +6350,138 @@ class PaddedDataset(GeneratorDataset):
self.padded_samples = padded_samples
class EMnistDataset(MappableDataset):
"""
A source dataset for reading and parsing the EMNIST dataset.
The generated dataset has two columns :py:obj:`[image, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
name (str): Name of splits for this dataset, can be "byclass", "bymerge", "balanced", "letters", "digits"
or "mnist".
usage (str, optional): Usage of this dataset, can be "train", "test" or "all".
(default=None, will read all samples).
num_samples (int, optional): The number of images to be included in the dataset
(default=None, will read all images).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, will use value set in the config).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the max sample number of per shard.
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
argument can only be specified when `num_shards` is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> emnist_dataset_dir = "/path/to/emnist_dataset_directory"
>>>
>>> # Read 3 samples from EMNIST dataset
>>> dataset = ds.EMnistDataset(dataset_dir=emnist_dataset_dir, name="mnist", num_samples=3)
>>>
>>> # Note: In emnist_dataset dataset, each dictionary has keys "image" and "label"
About EMNIST dataset:
The EMNIST dataset is a set of handwritten character digits derived from the NIST Special
Database 19 and converted to a 28x28 pixel image format and dataset structure that directly
matches the MNIST dataset. Further information on the dataset contents and conversion process
can be found in the paper available at https://arxiv.org/abs/1702.05373v1.
The numbers of characters and classes of each split of EMNIST are as follows:
By Class: 814,255 characters and 62 unbalanced classes.
By Merge: 814,255 characters and 47 unbalanced classes.
Balanced: 131,600 characters and 47 balanced classes.
Letters: 145,600 characters and 26 balanced classes.
Digits: 280,000 characters and 10 balanced classes.
MNIST: 70,000 characters and 10 balanced classes.
Here is the original EMNIST dataset structure.
You can unzip the dataset files into this directory structure and read by MindSpore's API.
.. code-block::
.
mnist_dataset_dir
emnist-mnist-train-images-idx3-ubyte
emnist-mnist-train-labels-idx1-ubyte
emnist-mnist-test-images-idx3-ubyte
emnist-mnist-test-labels-idx1-ubyte
...
Citation:
.. code-block::
@article{cohen_afshar_tapson_schaik_2017,
title = {EMNIST: Extending MNIST to handwritten letters},
DOI = {10.1109/ijcnn.2017.7966217},
journal = {2017 International Joint Conference on Neural Networks (IJCNN)},
author = {Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
year = {2017},
howpublished = {https://www.westernsydney.edu.au/icns/reproducible_research/
publication_support_materials/emnist}
}
"""
@check_emnist_dataset
def __init__(self, dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.name = name
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.EMnistNode(self.dataset_dir, self.name, self.usage, self.sampler)
class FlickrDataset(MappableDataset):
"""
A source dataset for reading and parsing Flickr8k and Flickr30k dataset.

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1463,6 +1463,39 @@ def check_to_device_send(method):
return new_method
def check_emnist_dataset(method):
"""A wrapper that wraps a parameter checker emnist dataset"""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle']
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
name = param_dict.get('name')
check_valid_str(name, ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"], "name")
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "all"], "usage")
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
def check_flickr_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k)."""

View File

@ -23,6 +23,7 @@ SET(DE_UT_SRCS
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_div2k_test.cc
c_api_dataset_emnist_test.cc
c_api_dataset_flickr_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_manifest_test.cc

View File

@ -0,0 +1,368 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestEMnistTrainDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTrainDataset.";
// Create a EMnist Train Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "train", std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestEMnistTestDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTestDataset.";
// Create a EMNIST Test Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "train", std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestEMnistTrainDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTrainDatasetWithPipeline.";
// Create two Emnist Train Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds1 = EMnist(folder_path, "mnist", "train", std::make_shared<RandomSampler>(false, 5));
std::shared_ptr<Dataset> ds2 = EMnist(folder_path, "byclass", "train", std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 1;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 1;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds
std::vector<std::string> column_project = {"image", "label"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestEMnistTestDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTestDatasetWithPipeline.";
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
// Create two EMnist Test Dataset
std::shared_ptr<Dataset> ds1 = EMnist(folder_path, "mnist", "test", std::make_shared<RandomSampler>(false, 5));
std::shared_ptr<Dataset> ds2 = EMnist(folder_path, "mnist", "test", std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 1;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 1;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds
std::vector<std::string> column_project = {"image", "label"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestGetEMnistTrainDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetEMnistTrainDatasetSize.";
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
// Create a EMnist Train Dataset
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "train");
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 10);
std::shared_ptr<Dataset> ds2 = EMnist(folder_path, "byclass", "train");
EXPECT_NE(ds2, nullptr);
EXPECT_EQ(ds2->GetDatasetSize(), 10);
}
TEST_F(MindDataTestPipeline, TestGetEMnistTestDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetEMnistTestDatasetSize.";
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
// Create a EMnist Test Dataset
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "test");
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 10);
}
TEST_F(MindDataTestPipeline, TestEMnistTrainDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTrainDatasetGetters.";
// Create a EMnist Train Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "train");
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 10);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[0].ToString(), "<28,28,1>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 10);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 10);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 10);
}
TEST_F(MindDataTestPipeline, TestEMnistTestDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistTestDatasetGetters.";
// Create a EMnist Test Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "test");
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 10);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[0].ToString(), "<28,28,1>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 10);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 10);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 10);
}
TEST_F(MindDataTestPipeline, TestEMnistDatasetWithInvalidDir) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistDatasetWithInvalidDir.";
// Create a EMnist Dataset
std::shared_ptr<Dataset> ds = EMnist("", "mnist", "train", std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid EMnist input
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestEMnistDatasetWithInvalidUsage) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistDatasetWithInvalidUsage.";
// Create a EMnist Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "validation");
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid EMnist input, validation is not a valid usage
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestEMnistDatasetWithInvalidName) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistDatasetWithInvalidName.";
// Create a EMnist Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "validation", "train");
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid EMnist input, validation is not a valid name
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestEMnistDatasetWithNullSampler) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEMnistDatasetWithNullSampler.";
// Create a EMnist Dataset
std::string folder_path = datasets_root_path_ + "/testEMnistDataset";
std::shared_ptr<Dataset> ds = EMnist(folder_path, "mnist", "train", nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid EMnist input, sampler cannot be nullptr
EXPECT_EQ(iter, nullptr);
}

View File

@ -0,0 +1,481 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test EMnist dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testEMnistDataset"
def load_emnist(path, usage, name):
"""
load EMnist data
"""
image_path = []
label_path = []
image_ext = "images-idx3-ubyte"
label_ext = "labels-idx1-ubyte"
train_prefix = "emnist-" + name + "-train-"
test_prefix = "emnist-" + name + "-test-"
assert usage in ["train", "test", "all"]
if usage == "train":
image_path.append(os.path.realpath(os.path.join(path, train_prefix + image_ext)))
label_path.append(os.path.realpath(os.path.join(path, train_prefix + label_ext)))
elif usage == "test":
image_path.append(os.path.realpath(os.path.join(path, test_prefix + image_ext)))
label_path.append(os.path.realpath(os.path.join(path, test_prefix + label_ext)))
elif usage == "all":
image_path.append(os.path.realpath(os.path.join(path, test_prefix + image_ext)))
label_path.append(os.path.realpath(os.path.join(path, test_prefix + label_ext)))
image_path.append(os.path.realpath(os.path.join(path, train_prefix + image_ext)))
label_path.append(os.path.realpath(os.path.join(path, train_prefix + label_ext)))
assert len(image_path) == len(label_path)
images = []
labels = []
for i, _ in enumerate(image_path):
with open(image_path[i], 'rb') as image_file:
image_file.read(16)
image = np.fromfile(image_file, dtype=np.uint8)
image = image.reshape(-1, 28, 28, 1)
image[image > 0] = 255 # Perform binarization to maintain consistency with our API
images.append(image)
with open(label_path[i], 'rb') as label_file:
label_file.read(8)
label = np.fromfile(label_file, dtype=np.uint8)
labels.append(label)
images = np.concatenate(images, 0)
labels = np.concatenate(labels, 0)
return images, labels
def visualize_dataset(images, labels):
"""
Helper function to visualize the dataset samples
"""
num_samples = len(images)
for i in range(num_samples):
plt.subplot(1, num_samples, i + 1)
plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
plt.title(labels[i])
plt.show()
def test_emnist_content_check():
"""
Validate EMnistDataset image readings
"""
logger.info("Test EMnistDataset Op with content check")
# train mnist
train_data = ds.EMnistDataset(DATA_DIR, name="mnist", usage="train", num_samples=10, shuffle=False)
images, labels = load_emnist(DATA_DIR, "train", "mnist")
num_iter = 0
# in this example, each dictionary has keys "image" and "label"
image_list, label_list = [], []
for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_list.append(data["image"])
label_list.append("label {}".format(data["label"]))
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 10
# train byclass
train_data = ds.EMnistDataset(DATA_DIR, name="byclass", usage="train", num_samples=10, shuffle=False)
images, labels = load_emnist(DATA_DIR, "train", "byclass")
num_iter = 0
# in this example, each dictionary has keys "image" and "label"
image_list, label_list = [], []
for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_list.append(data["image"])
label_list.append("label {}".format(data["label"]))
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 10
# test
test_data = ds.EMnistDataset(DATA_DIR, name="mnist", usage="test", num_samples=10, shuffle=False)
images, labels = load_emnist(DATA_DIR, "test", "mnist")
num_iter = 0
# in this example, each dictionary has keys "image" and "label"
image_list, label_list = [], []
for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_list.append(data["image"])
label_list.append("label {}".format(data["label"]))
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 10
def test_emnist_basic():
"""
Validate EMnistDataset
"""
logger.info("Test EMnistDataset Op")
# case 1: test loading whole dataset
train_data = ds.EMnistDataset(DATA_DIR, "mnist", "train")
num_iter1 = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter1 += 1
assert num_iter1 == 10
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test")
num_iter = 0
for _ in test_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 10
# case 2: test num_samples
train_data = ds.EMnistDataset(DATA_DIR, "byclass", "train", num_samples=5)
num_iter2 = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 5
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=5)
num_iter2 = 0
for _ in test_data.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 5
# case 3: test repeat
train_data = ds.EMnistDataset(DATA_DIR, "byclass", "train", num_samples=2)
train_data = train_data.repeat(5)
num_iter3 = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 10
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=2)
test_data = test_data.repeat(5)
num_iter3 = 0
for _ in test_data.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 10
# case 4: test batch with drop_remainder=False
train_data = ds.EMnistDataset(DATA_DIR, "byclass", "train", num_samples=10)
assert train_data.get_dataset_size() == 10
assert train_data.get_batch_size() == 1
train_data = train_data.batch(batch_size=7) # drop_remainder is default to be False
assert train_data.get_dataset_size() == 2
assert train_data.get_batch_size() == 7
num_iter4 = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter4 += 1
assert num_iter4 == 2
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=10)
assert test_data.get_dataset_size() == 10
assert test_data.get_batch_size() == 1
test_data = test_data.batch(
batch_size=7) # drop_remainder is default to be False
assert test_data.get_dataset_size() == 2
assert test_data.get_batch_size() == 7
num_iter4 = 0
for _ in test_data.create_dict_iterator(num_epochs=1):
num_iter4 += 1
assert num_iter4 == 2
# case 5: test batch with drop_remainder=True
train_data = ds.EMnistDataset(DATA_DIR, "byclass", "train", num_samples=10)
assert train_data.get_dataset_size() == 10
assert train_data.get_batch_size() == 1
train_data = train_data.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
assert train_data.get_dataset_size() == 1
assert train_data.get_batch_size() == 7
num_iter5 = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter5 += 1
assert num_iter5 == 1
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=10)
assert test_data.get_dataset_size() == 10
assert test_data.get_batch_size() == 1
test_data = test_data.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
assert test_data.get_dataset_size() == 1
assert test_data.get_batch_size() == 7
num_iter5 = 0
for _ in test_data.create_dict_iterator(num_epochs=1):
num_iter5 += 1
assert num_iter5 == 1
# case 6: test get_col_names
dataset = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=10)
assert dataset.get_col_names() == ["image", "label"]
def test_emnist_pk_sampler():
"""
Test EMnistDataset with PKSampler
"""
logger.info("Test EMnistDataset Op with PKSampler")
golden = [0, 0, 0, 1, 1, 1]
sampler = ds.PKSampler(3)
train_data = ds.EMnistDataset(DATA_DIR, "mnist", "train", sampler=sampler)
num_iter = 0
label_list = []
for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
label_list.append(item["label"])
num_iter += 1
np.testing.assert_array_equal(golden, label_list)
assert num_iter == 6
sampler = ds.PKSampler(3)
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "train", sampler=sampler)
num_iter = 0
label_list = []
for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True):
label_list.append(item["label"])
num_iter += 1
np.testing.assert_array_equal(golden, label_list)
assert num_iter == 6
def test_emnist_sequential_sampler():
"""
Test EMnistDataset with SequentialSampler
"""
logger.info("Test EMnistDataset Op with SequentialSampler")
num_samples = 10
sampler = ds.SequentialSampler(num_samples=num_samples)
train_data1 = ds.EMnistDataset(DATA_DIR, "mnist", "train", sampler=sampler)
train_data2 = ds.EMnistDataset(DATA_DIR, "mnist", "train", shuffle=False, num_samples=num_samples)
label_list1, label_list2 = [], []
num_iter = 0
for item1, item2 in zip(train_data1.create_dict_iterator(num_epochs=1),
train_data2.create_dict_iterator(num_epochs=1)):
label_list1.append(item1["label"].asnumpy())
label_list2.append(item2["label"].asnumpy())
num_iter += 1
np.testing.assert_array_equal(label_list1, label_list2)
assert num_iter == num_samples
num_samples = 10
sampler = ds.SequentialSampler(num_samples=num_samples)
test_data1 = ds.EMnistDataset(DATA_DIR, "mnist", "test", sampler=sampler)
test_data2 = ds.EMnistDataset(DATA_DIR, "mnist", "test", shuffle=False, num_samples=num_samples)
label_list1, label_list2 = [], []
num_iter = 0
for item1, item2 in zip(test_data1.create_dict_iterator(num_epochs=1),
test_data2.create_dict_iterator(num_epochs=1)):
label_list1.append(item1["label"].asnumpy())
label_list2.append(item2["label"].asnumpy())
num_iter += 1
np.testing.assert_array_equal(label_list1, label_list2)
assert num_iter == num_samples
def test_emnist_exception():
"""
Test error cases for EMnistDataset
"""
logger.info("Test error cases for EMnistDataset")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.EMnistDataset(DATA_DIR, "byclass", "train", shuffle=False, sampler=ds.PKSampler(3))
ds.EMnistDataset(DATA_DIR, "mnist", "test", shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.EMnistDataset(DATA_DIR, "mnist", "train", sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
ds.EMnistDataset(DATA_DIR, "mnist", "test", sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.EMnistDataset(DATA_DIR, "byclass", "train", num_shards=10)
ds.EMnistDataset(DATA_DIR, "mnist", "test", num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.EMnistDataset(DATA_DIR, "mnist", "train", shard_id=0)
ds.EMnistDataset(DATA_DIR, "mnist", "test", shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.EMnistDataset(DATA_DIR, "byclass", "train", num_shards=5, shard_id=-1)
ds.EMnistDataset(DATA_DIR, "mnist", "test", num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.EMnistDataset(DATA_DIR, "mnist", "train", num_shards=5, shard_id=5)
ds.EMnistDataset(DATA_DIR, "mnist", "test", num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.EMnistDataset(DATA_DIR, "byclass", "train", num_shards=2, shard_id=5)
ds.EMnistDataset(DATA_DIR, "mnist", "test", num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.EMnistDataset(DATA_DIR, "mnist", "train", shuffle=False, num_parallel_workers=0)
ds.EMnistDataset(DATA_DIR, "mnist", "test", shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.EMnistDataset(DATA_DIR, "byclass", "train", shuffle=False, num_parallel_workers=256)
ds.EMnistDataset(DATA_DIR, "mnist", "test", shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.EMnistDataset(DATA_DIR, "mnist", "train", shuffle=False, num_parallel_workers=-2)
ds.EMnistDataset(DATA_DIR, "mnist", "test", shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.EMnistDataset(DATA_DIR, "mnist", "train", num_shards=2, shard_id="0")
ds.EMnistDataset(DATA_DIR, "mnist", "test", num_shards=2, shard_id="0")
def exception_func(item):
raise Exception("Error occur!")
error_msg_8 = "The corresponding data files"
with pytest.raises(RuntimeError, match=error_msg_8):
data = ds.EMnistDataset(DATA_DIR, "mnist", "train")
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
with pytest.raises(RuntimeError, match=error_msg_8):
data = ds.EMnistDataset(DATA_DIR, "mnist", "train")
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.__iter__():
pass
with pytest.raises(RuntimeError, match=error_msg_8):
data = ds.EMnistDataset(DATA_DIR, "mnist", "train")
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
for _ in data.__iter__():
pass
def test_emnist_visualize(plot=False):
"""
Visualize EMnistDataset results
"""
logger.info("Test EMnistDataset visualization")
train_data = ds.EMnistDataset(DATA_DIR, "mnist", "train", num_samples=10, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert image.shape == (28, 28, 1)
assert image.dtype == np.uint8
assert label.dtype == np.uint32
num_iter += 1
assert num_iter == 10
if plot:
visualize_dataset(image_list, label_list)
test_data = ds.EMnistDataset(DATA_DIR, "mnist", "test", num_samples=10, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert image.shape == (28, 28, 1)
assert image.dtype == np.uint8
assert label.dtype == np.uint32
num_iter += 1
assert num_iter == 10
if plot:
visualize_dataset(image_list, label_list)
def test_emnist_usage():
"""
Validate EMnistDataset image readings
"""
logger.info("Test EMnistDataset usage flag")
def test_config(usage, emnist_path=None):
emnist_path = DATA_DIR if emnist_path is None else emnist_path
try:
data = ds.EMnistDataset(emnist_path, "mnist", usage=usage, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("train") == 10
assert test_config("test") == 10
assert test_config("all") == 20
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
# change this directory to the folder that contains all emnist files
all_files_path = None
# the following tests on the entire datasets
if all_files_path is not None:
assert test_config("train", all_files_path) == 10000
assert test_config("test", all_files_path) == 60000
assert test_config("all", all_files_path) == 70000
assert ds.EMnistDataset(all_files_path, "mnist", usage="test").get_dataset_size() == 10000
assert ds.EMnistDataset(all_files_path, "mnist", usage="test").get_dataset_size() == 60000
assert ds.EMnistDataset(all_files_path, "mnist", usage="all").get_dataset_size() == 70000
def test_emnist_name():
"""
Validate EMnistDataset image readings
"""
def test_config(name, usage, emnist_path=None):
emnist_path = DATA_DIR if emnist_path is None else emnist_path
try:
data = ds.EMnistDataset(emnist_path, name, usage=usage, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("mnist", "train") == 10
assert test_config("mnist", "test") == 10
assert test_config("byclass", "train") == 10
assert "name is not within the valid set of " + \
"['byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist']" in test_config("invalid", "train")
assert "Argument name with value ['list'] is not of type [<class 'str'>]" in test_config(["list"], "train")
if __name__ == '__main__':
test_emnist_content_check()
test_emnist_basic()
test_emnist_pk_sampler()
test_emnist_sequential_sampler()
test_emnist_exception()
test_emnist_visualize(plot=True)
test_emnist_usage()
test_emnist_name()