forked from mindspore-Ecosystem/mindspore
!18018 [assistant][ops]New operator implementation, include FashionMnistDataset
Merge pull request !18018 from 张璇/fashionmnist_dataset
This commit is contained in:
commit
5d8be044c9
|
@ -98,6 +98,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
@ -1097,6 +1098,29 @@ FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
@ -179,6 +180,17 @@ PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FashionMnistNode, DatasetNode, std::shared_ptr<FashionMnistNode>>(
|
||||
*m, "FashionMnistNode", "to create a FashionMnistNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
||||
auto fashion_mnist =
|
||||
std::make_shared<FashionMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(fashion_mnist->ValidateParams());
|
||||
return fashion_mnist;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FlickrNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
|
||||
|
|
|
@ -26,6 +26,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
fake_image_op.cc
|
||||
places365_op.cc
|
||||
photo_tour_op.cc
|
||||
fashion_mnist_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/fashion_mnist_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
FashionMnistOp::FashionMnistOp(const std::string &usage, int32_t num_workers, const std::string &folder_path,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler)
|
||||
: MnistOp(usage, num_workers, folder_path, queue_size, std::move(data_schema), std::move(sampler)) {}
|
||||
|
||||
Status FashionMnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader().
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
auto op =
|
||||
std::make_shared<FashionMnistOp>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
|
||||
|
||||
RETURN_IF_NOT_OK(op->WalkAllFiles());
|
||||
|
||||
for (size_t i = 0; i < op->image_names_.size(); ++i) {
|
||||
std::ifstream image_reader;
|
||||
image_reader.open(op->image_names_[i], std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(image_reader.is_open(),
|
||||
"Invalid file, failed to open image file: " + op->image_names_[i]);
|
||||
std::ifstream label_reader;
|
||||
label_reader.open(op->label_names_[i], std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(label_reader.is_open(),
|
||||
"Invalid file, failed to open label file: " + op->label_names_[i]);
|
||||
uint32_t num_images;
|
||||
Status s = op->CheckImage(op->image_names_[i], &image_reader, &num_images);
|
||||
image_reader.close();
|
||||
RETURN_IF_NOT_OK(s);
|
||||
|
||||
uint32_t num_labels;
|
||||
s = op->CheckLabel(op->label_names_[i], &label_reader, &num_labels);
|
||||
label_reader.close();
|
||||
RETURN_IF_NOT_OK(s);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels),
|
||||
"Invalid data, num of images is not equal to num of labels.");
|
||||
*count = *count + num_images;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FASHION_MNIST_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FASHION_MNIST_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Forward declares.
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
class FashionMnistOp : public MnistOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] usage Usage of this dataset, can be 'train', 'test' or 'all'.
|
||||
/// \param[in] num_workers Number of workers reading images in parallel.
|
||||
/// \param[in] folder_path Dir directory of fashionmnist.
|
||||
/// \param[in] queue_size Connector queue size.
|
||||
/// \param[in] data_schema The schema of the fashionmnist dataset.
|
||||
/// \param[in] Sampler Tells FashionMnistOp what to read.
|
||||
FashionMnistOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~FashionMnistOp() = default;
|
||||
|
||||
/// \brief Function to count the number of samples in the Fashion-MNIST dataset.
|
||||
/// \param[in] dir Path to the Fashion-MNIST directory.
|
||||
/// \param[in] usage Usage of this dataset, can be 'train', 'test' or 'all'.
|
||||
/// \param[in] count Output arg that will hold the minimum of the actual dataset size and numSamples.
|
||||
/// \return Status The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "FashionMnistOp"; }
|
||||
|
||||
/// \brief Dataset name getter.
|
||||
/// \param[in] upper Whether to get upper name.
|
||||
/// \return Dataset name of the current Op.
|
||||
std::string DatasetName(bool upper = false) const override { return upper ? "FashionMnist" : "fashion mnist"; }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FASHION_MNIST_OP_H_
|
|
@ -85,6 +85,7 @@ constexpr char kCSVNode[] = "CSVDataset";
|
|||
constexpr char kDIV2KNode[] = "DIV2KDataset";
|
||||
constexpr char kEMnistNode[] = "EMnistDataset";
|
||||
constexpr char kFakeImageNode[] = "FakeImageDataset";
|
||||
constexpr char kFashionMnistNode[] = "FashionMnistDataset";
|
||||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
|
|
|
@ -14,6 +14,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
div2k_node.cc
|
||||
emnist_node.cc
|
||||
fake_image_node.cc
|
||||
fashion_mnist_node.cc
|
||||
flickr_node.cc
|
||||
image_folder_node.cc
|
||||
manifest_node.cc
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/fashion_mnist_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
FashionMnistNode::FashionMnistNode(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> FashionMnistNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<FashionMnistNode>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void FashionMnistNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status FashionMnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("FashionMnistNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FashionMnistNode", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("FashionMnistNode", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FashionMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<FashionMnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node.
|
||||
Status FashionMnistNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size.
|
||||
Status FashionMnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(FashionMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FashionMnistNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FASHION_MNIST_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FASHION_MNIST_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class FashionMnistNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
FashionMnistNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~FashionMnistNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kFashionMnistNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[in] out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[out] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting.
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FASHION_MNIST_NODE_H_
|
|
@ -2109,6 +2109,81 @@ inline std::shared_ptr<FakeImageDataset> FakeImage(int32_t num_images, const std
|
|||
return std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class FashionMnistDataset
|
||||
/// \brief A source dataset that reads and parses FASHION-MNIST dataset.
|
||||
class FashionMnistDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of FashionMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all".
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of FashionMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of FashionMnistDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// Destructor of FashionMnistDataset.
|
||||
~FashionMnistDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a FashionMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all" (default = "all").
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the FashionMnistDataset.
|
||||
inline std::shared_ptr<FashionMnistDataset> FashionMnist(
|
||||
const std::string &dataset_dir, const std::string &usage = "all",
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FashionMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FashionMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the FashionMnistDataset.
|
||||
inline std::shared_ptr<FashionMnistDataset> FashionMnist(const std::string &dataset_dir, const std::string &usage,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FashionMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FashionMnistDataset.
|
||||
/// \note The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Usage of FASHION-MNIST, can be "train", "test" or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the FashionMnistDataset.
|
||||
inline std::shared_ptr<FashionMnistDataset> FashionMnist(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FashionMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class FlickrDataset
|
||||
/// \brief A source dataset for reading and parsing Flickr dataset.
|
||||
class FlickrDataset : public Dataset {
|
||||
|
|
|
@ -41,6 +41,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class DIV2KDataset;
|
||||
friend class EMnistDataset;
|
||||
friend class FakeImageDataset;
|
||||
friend class FashionMnistDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class ManifestDataset;
|
||||
|
|
|
@ -3279,6 +3279,128 @@ class RangeDataset(MappableDataset):
|
|||
return self.dataset_size
|
||||
|
||||
|
||||
class FashionMnistDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the FASHION-MNIST dataset.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[image, label]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Usage of this dataset, can be `train`, `test` or `all`. `train` will read from 60,000
|
||||
train samples, `test` will read from 10,000 test samples, `all` will read from all 70,000 samples.
|
||||
(default=None, will read all samples)
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||
argument can only be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> fashion_mnist_dataset_dir = "/path/to/fashion_mnist_dataset_directory"
|
||||
>>>
|
||||
>>> # Read 3 samples from FASHIONMNIST dataset
|
||||
>>> dataset = ds.FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, num_samples=3)
|
||||
>>>
|
||||
>>> # Note: In FASHIONMNIST dataset, each dictionary has keys "image" and "label"
|
||||
|
||||
About Fashion-MNIST dataset:
|
||||
|
||||
Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of 60,000 examples and
|
||||
a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.
|
||||
We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking
|
||||
machine learning algorithms. It shares the same image size and structure of training and testing splits.
|
||||
|
||||
Here is the original Fashion-MNIST dataset structure.
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── fashionmnist_dataset_dir
|
||||
├── t10k-images-idx3-ubyte
|
||||
├── t10k-labels-idx1-ubyte
|
||||
├── train-images-idx3-ubyte
|
||||
└── train-labels-idx1-ubyte
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@online{xiao2017/online,
|
||||
author = {Han Xiao and Kashif Rasul and Roland Vollgraf},
|
||||
title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms},
|
||||
date = {2017-08-28},
|
||||
year = {2017},
|
||||
eprintclass = {cs.LG},
|
||||
eprinttype = {arXiv},
|
||||
eprint = {cs.LG/1708.07747},
|
||||
}
|
||||
"""
|
||||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
|
||||
sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.FashionMnistNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class ImageFolderDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads images from a tree of directories.
|
||||
|
|
|
@ -25,6 +25,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_div2k_test.cc
|
||||
c_api_dataset_emnist_test.cc
|
||||
c_api_dataset_fake_image_test.cc
|
||||
c_api_dataset_fashion_mnist_test.cc
|
||||
c_api_dataset_flickr_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
|
|
|
@ -0,0 +1,287 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
// Test fixture for the dataset C++ API pipeline tests; inherits shared test
// helpers (e.g. datasets_root_path_, used by the tests below) from
// UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
|
||||
|
||||
/// Feature: FashionMnistTestDataset.
|
||||
/// Description: test basic usage of FashionMnistTestDataset.
|
||||
/// Expectation: get correct data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistTestDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistTestDataset.";
|
||||
|
||||
// Create a FashionMnist Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "test", std::make_shared<RandomSampler>(false, 10));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistTestDatasetWithPipeline.
|
||||
/// Description: test FashionMnistTestDataset with pipeline.
|
||||
/// Expectation: get correct data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistTestDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistTestDatasetWithPipeline.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
|
||||
// Create two FashionMnist Test Dataset
|
||||
std::shared_ptr<Dataset> ds1 = FashionMnist(folder_path, "test", std::make_shared<RandomSampler>(false, 10));
|
||||
std::shared_ptr<Dataset> ds2 = FashionMnist(folder_path, "test", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 1;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 1;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistIteratorOneColumn.
|
||||
/// Description: test iterator of FashionMnistDataset with only the "image" column.
|
||||
/// Expectation: get correct data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistIteratorOneColumn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistIteratorOneColumn.";
|
||||
// Create a FashionMnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "all", std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 2;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// Only select "image" column and drop others
|
||||
std::vector<std::string> columns = {"image"};
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns, -1);
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::vector<mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
std::vector<int64_t> expect_image = {2, 28, 28, 1};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
for (auto &v : row) {
|
||||
MS_LOG(INFO) << "image shape:" << v.Shape();
|
||||
EXPECT_EQ(expect_image, v.Shape());
|
||||
}
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistTestDatasetSize.
|
||||
/// Description: test usage of get the size of FashionMnistTestDataset.
|
||||
/// Expectation: get correct data.
|
||||
TEST_F(MindDataTestPipeline, TestGetFashionMnistTestDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetFashionMnistTestDatasetSize.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
|
||||
// Create a FashionMnist Test Dataset
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "test");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistTestDatasetGetters.
|
||||
/// Description: test DatasetGetters of FashionMnistTestDataset.
|
||||
/// Expectation: get correct the value.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistTestDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistTestDatasetGetters.";
|
||||
|
||||
// Create a FashionMnist Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "test");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image", "label"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "uint32");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<28,28,1>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistIteratorWrongColumn.
|
||||
/// Description: test iterator of FashionMnistDataset with wrong column.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistIteratorWrongColumn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistIteratorOneColumn.";
|
||||
// Create a FashionMnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "all", std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Pass wrong column name
|
||||
std::vector<std::string> columns = {"digital"};
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistDatasetFail.
|
||||
/// Description: test failure of FashionMnistDataset.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistDatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistDatasetFail.";
|
||||
|
||||
// Create a FashionMnist Dataset
|
||||
std::shared_ptr<Dataset> ds = FashionMnist("", "train", std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FashionMnist input
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistDatasetWithInvalidUsageFail.
|
||||
/// Description: test FashionMnistDataset with invalid usage.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistDatasetWithInvalidUsageFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistDatasetWithInvalidUsageFail.";
|
||||
|
||||
// Create a FashionMnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "validation");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FashionMnist input, validation is not a valid usage
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FashionMnistDatasetWithNullSamplerFail.
|
||||
/// Description: test FashionMnistDataset with null sampler.
|
||||
/// Expectation: get none piece of data.
|
||||
TEST_F(MindDataTestPipeline, TestFashionMnistDatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFashionMnistUDatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a FashionMnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = FashionMnist(folder_path, "all", nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FashionMnist input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test FashionMnist dataset operators
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testMnistData"
|
||||
|
||||
|
||||
def load_fashion_mnist(path):
    """
    Feature: load_fashion_mnist.
    Description: load FashionMnistDataset.
    Expectation: get data of FashionMnistDataset.
    """
    # Resolve the test-split idx files; labels carry an 8-byte header and
    # images a 16-byte header, which np.fromfile skips via `offset`.
    label_file = os.path.realpath(os.path.join(path, 't10k-labels-idx1-ubyte'))
    image_file = os.path.realpath(os.path.join(path, 't10k-images-idx3-ubyte'))
    labels = np.fromfile(label_file, dtype=np.uint8, offset=8)
    images = np.fromfile(image_file, dtype=np.uint8, offset=16).reshape(-1, 28, 28, 1)
    # Perform binarization to maintain consistency with our API.
    images[images > 0] = 255
    return images, labels
|
||||
|
||||
|
||||
def visualize_dataset(images, labels):
    """
    Feature: visualize_dataset.
    Description: visualize FashionMnistDataset.
    Expectation: plot images.
    """
    # Lay all samples out in a single row of subplots, one per image.
    for index in range(len(images)):
        plt.subplot(1, len(images), index + 1)
        plt.imshow(images[index].squeeze(), cmap=plt.cm.gray)
        plt.title(labels[index])
    plt.show()
|
||||
|
||||
|
||||
def test_fashion_mnist_content_check():
    """
    Feature: test_fashion_mnist_content_check.
    Description: validate FashionMnistDataset image readings.
    Expectation: get correct value.
    """
    logger.info("Test FashionMnistDataset Op with content check")
    data1 = ds.FashionMnistDataset(DATA_DIR, num_samples=100, shuffle=False)
    # Ground truth read straight from the idx files on disk.
    images, labels = load_fashion_mnist(DATA_DIR)
    num_iter = 0
    # Each row is a dict with keys "image" and "label"; compare element-wise
    # against the raw files. (Removed dead image_list/label_list accumulators
    # that were never read.)
    for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 100
|
||||
|
||||
|
||||
def test_fashion_mnist_basic():
    """
    Feature: test_fashion_mnist_basic.
    Description: test basic usage of FashionMnistDataset.
    Expectation: get correct data.
    """
    logger.info("Test FashionMnistDataset Op")

    def count_rows(dataset):
        # Drain the dataset once and return the number of rows produced.
        rows = 0
        for _ in dataset.create_dict_iterator(num_epochs=1):
            rows += 1
        return rows

    # case 1: test loading whole dataset
    assert count_rows(ds.FashionMnistDataset(DATA_DIR)) == 10000

    # case 2: test num_samples
    assert count_rows(ds.FashionMnistDataset(DATA_DIR, num_samples=500)) == 500

    # case 3: test repeat
    assert count_rows(ds.FashionMnistDataset(DATA_DIR, num_samples=200).repeat(5)) == 1000

    # case 4: test batch with drop_remainder=False
    data4 = ds.FashionMnistDataset(DATA_DIR, num_samples=100)
    assert data4.get_dataset_size() == 100
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder is default to be False
    assert data4.get_dataset_size() == 15
    assert data4.get_batch_size() == 7
    assert count_rows(data4) == 15

    # case 5: test batch with drop_remainder=True
    data5 = ds.FashionMnistDataset(DATA_DIR, num_samples=100)
    assert data5.get_dataset_size() == 100
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=7, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data5.get_dataset_size() == 14
    assert data5.get_batch_size() == 7
    assert count_rows(data5) == 14

    # case 6: test get_col_names
    data6 = ds.FashionMnistDataset(DATA_DIR, "train", num_samples=10)
    assert data6.get_col_names() == ["image", "label"]
|
||||
|
||||
|
||||
def test_fashion_mnist_pk_sampler():
    """
    Feature: test_fashion_mnist_pk_sampler.
    Description: test usage of FashionMnistDataset with PKSampler.
    Expectation: get correct data.
    """
    logger.info("Test FashionMnistDataset Op with PKSampler")
    # PKSampler(3) yields exactly 3 samples per class, classes in order.
    golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
              5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
    data = ds.FashionMnistDataset(DATA_DIR, sampler=ds.PKSampler(3))
    label_list = [item["label"]
                  for item in data.create_dict_iterator(num_epochs=1, output_numpy=True)]
    np.testing.assert_array_equal(golden, label_list)
    assert len(label_list) == 30
|
||||
|
||||
|
||||
def test_fashion_mnist_sequential_sampler():
    """
    Feature: test_fashion_mnist_sequential_sampler.
    Description: test usage of FashionMnistDataset with SequentialSampler.
    Expectation: get correct data.
    """
    logger.info("Test FashionMnistDataset Op with SequentialSampler")
    num_samples = 50
    # An explicit SequentialSampler must yield the same rows, in the same
    # order, as shuffle=False with the same sample count.
    data1 = ds.FashionMnistDataset(DATA_DIR, sampler=ds.SequentialSampler(num_samples=num_samples))
    data2 = ds.FashionMnistDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
    sampled_labels, sequential_labels = [], []
    rows_seen = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
        sampled_labels.append(item1["label"].asnumpy())
        sequential_labels.append(item2["label"].asnumpy())
        rows_seen += 1
    np.testing.assert_array_equal(sampled_labels, sequential_labels)
    assert rows_seen == num_samples
|
||||
|
||||
|
||||
def test_fashion_mnist_exception():
    """
    Feature: test_fashion_mnist_exception.
    Description: test error cases for FashionMnistDataset.
    Expectation: raise exception.
    """
    logger.info("Test error cases for FashionMnistDataset")

    # Mutually exclusive argument combinations.
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.FashionMnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.FashionMnistDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    # Incomplete sharding specifications.
    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.FashionMnistDataset(DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.FashionMnistDataset(DATA_DIR, shard_id=0)

    # Out-of-range shard ids.
    error_msg_5 = "Input shard_id is not within the required interval"
    for num_shards, shard_id in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match=error_msg_5):
            ds.FashionMnistDataset(DATA_DIR, num_shards=num_shards, shard_id=shard_id)

    # Invalid worker counts.
    error_msg_6 = "num_parallel_workers exceeds"
    for workers in (0, 256, -2):
        with pytest.raises(ValueError, match=error_msg_6):
            ds.FashionMnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=workers)

    # Wrongly typed shard_id.
    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.FashionMnistDataset(DATA_DIR, num_shards=2, shard_id="0")

    def raise_error(item):
        # Used inside map() to force a runtime failure during iteration.
        raise Exception("Error occur!")

    # Failures raised by map() workers must surface as RuntimeError.
    error_msg_8 = "The corresponding data files"
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.FashionMnistDataset(DATA_DIR)
        data = data.map(operations=raise_error, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.FashionMnistDataset(DATA_DIR)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=raise_error, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.FashionMnistDataset(DATA_DIR)
        data = data.map(operations=raise_error, input_columns=["label"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
|
||||
|
||||
|
||||
def test_fashion_mnist_visualize(plot=False):
    """
    Feature: test_fashion_mnist_visualize.
    Description: visualize FashionMnistDataset results.
    Expectation: get correct data and plot them.
    """
    logger.info("Test FashionMnistDataset visualization")

    data1 = ds.FashionMnistDataset(DATA_DIR, num_samples=10, shuffle=False)
    image_list, label_list = [], []
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image, label = item["image"], item["label"]
        # Each sample must be a 28x28x1 uint8 image with a uint32 label.
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        image_list.append(image)
        label_list.append("label {}".format(label))
        num_iter += 1
    assert num_iter == 10

    if plot:
        visualize_dataset(image_list, label_list)
|
||||
|
||||
|
||||
def test_fashion_mnist_usage():
    """
    Feature: test_fashion_mnist_usage.
    Description: validate FashionMnistDataset image readings.
    Expectation: get correct data.
    """
    logger.info("Test FashionMnistDataset usage flag")

    def test_config(usage, fashion_mnist_path=None):
        # Read the dataset with the given usage; return the row count on
        # success, or the text of the raised error on failure.
        if fashion_mnist_path is None:
            fashion_mnist_path = DATA_DIR
        try:
            data = ds.FashionMnistDataset(fashion_mnist_path, usage=usage, shuffle=False)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    # The bundled test data only provides the test split (10000 rows).
    assert test_config("test") == 10000
    assert test_config("all") == 10000
    assert "FashionMnistDataset API can't read the data file (interface mismatch or no data found)" \
           in test_config("train")
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

    # change this directory to the folder that contains all fashionmnist files
    all_files_path = None
    # the following tests on the entire datasets
    if all_files_path is not None:
        assert test_config("train", all_files_path) == 60000
        assert test_config("test", all_files_path) == 10000
        assert test_config("all", all_files_path) == 70000
        assert ds.FashionMnistDataset(all_files_path, usage="train").get_dataset_size() == 60000
        assert ds.FashionMnistDataset(all_files_path, usage="test").get_dataset_size() == 10000
        assert ds.FashionMnistDataset(all_files_path, usage="all").get_dataset_size() == 70000
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual entry point: run the full FashionMnist test suite directly
    # (plot=True additionally displays the sample images).
    test_fashion_mnist_content_check()
    test_fashion_mnist_basic()
    test_fashion_mnist_pk_sampler()
    test_fashion_mnist_sequential_sampler()
    test_fashion_mnist_exception()
    test_fashion_mnist_visualize(plot=True)
    test_fashion_mnist_usage()
|
Loading…
Reference in New Issue