[feat][assistant][I3J6VL] add new data operator USPS

This commit is contained in:
ckczzj 2021-06-02 13:56:12 +08:00
parent 72fb9207f8
commit d7c702e9d4
18 changed files with 1553 additions and 1 deletions

View File

@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif
@ -1198,6 +1199,14 @@ TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_f
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
USPSDataset::USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                         ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                         const std::shared_ptr<DatasetCache> &cache) {
  // Build the USPS IR node and hold it through the generic DatasetNode handle.
  ir_node_ = std::static_pointer_cast<DatasetNode>(std::make_shared<USPSNode>(
    CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards, shard_id, cache));
}
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
bool decode, const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,

View File

@ -44,6 +44,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif
@ -286,6 +287,18 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
                  // Expose USPSNode to Python. `shuffle` arrives as a plain int
                  // and is converted to the C++ ShuffleMode enum; the cache is
                  // always nullptr here (attached separately on the Python side).
                  (void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
                                                                                    "to create an USPSNode")
                    .def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
                                     int32_t num_shards, int32_t shard_id) {
                      auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
                                                             num_shards, shard_id, nullptr);
                      // Validate eagerly so bad arguments raise at construction time in Python.
                      THROW_IF_ERROR(usps->ValidateParams());
                      return usps;
                    }));
                }));
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(py::init([](std::string dataset_dir, std::string task, std::string usage,

View File

@ -14,6 +14,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
clue_op.cc
csv_op.cc
album_op.cc
usps_op.cc
mappable_leaf_op.cc
nonmappable_leaf_op.cc
cityscapes_op.cc

View File

@ -0,0 +1,350 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/usps_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
// USPS images are fixed-size 16x16 single-channel scans, so each flattened
// image holds exactly 256 pixel values.
constexpr int64_t kUSPSImageHeight = 16;
constexpr int64_t kUSPSImageWidth = 16;
constexpr int64_t kUSPSImageChannel = 1;
constexpr int64_t kUSPSImageSize = kUSPSImageHeight * kUSPSImageWidth * kUSPSImageChannel;
// Constructor: forwards scheduling/sharding parameters to NonMappableLeafOp and
// stores the USPS-specific configuration (directory, usage, column schema).
USPSOp::USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
               int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
               bool shuffle_files, int32_t num_devices, int32_t device_id)
    : NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
                        device_id),
      usage_(usage),
      dataset_dir_(dataset_dir),
      data_schema_(std::move(data_schema)) {}
void USPSOp::Print(std::ostream &out, bool show_all) const {
  // The base class prints the common operator info in both modes.
  ParallelOp::Print(out, show_all);
  if (!show_all) {
    // Summary mode: nothing USPS-specific beyond a trailing newline.
    out << "\n";
    return;
  }
  // Detailed mode: dump USPS-specific state followed by the data schema.
  out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
      << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nUSPS directory: " << dataset_dir_
      << "\nUSPS usage: " << usage_ << "\n\n";
  out << "\nData schema:\n";
  out << *data_schema_ << "\n\n";
}
// Instantiates the filename index, IO-block queues, and worker connectors.
// @return Status - the error code returned.
Status USPSOp::Init() {
  RETURN_IF_NOT_OK(this->GetFiles());
  RETURN_IF_NOT_OK(filename_index_->insert(data_files_list_));
  // Divide in floating point so std::ceil can round up. The previous integer
  // division truncated before the ceil (making it a no-op), which under-sized
  // the queue whenever the file count was not a multiple of num_workers_.
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(data_files_list_.size() / static_cast<double>(num_workers_)) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);
  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
  return Status::OK();
}
// Counts the total number of non-empty rows across the data files selected by
// `dir` + `usage`, by building a throw-away USPSOp and scanning its file list.
// @param dir - path to the USPS directory.
// @param usage - 'train', 'test' or 'all'.
// @param count - output: total row count.
// @return Status - the error code returned.
Status USPSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;
  // Schema mirrors the one built in USPSNode::Build: uint8 image + uint32 scalar label.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();
  int32_t worker_connector_size = cfg->worker_connector_size();
  const int64_t num_samples = 0;  // 0 = no sample cap while counting
  const int32_t num_devices = 1;
  const int32_t device_id = 0;
  bool shuffle = false;
  auto op = std::make_shared<USPSOp>(dir, usage, std::move(schema), num_workers, worker_connector_size, num_samples,
                                     op_connector_size, shuffle, num_devices, device_id);
  RETURN_IF_NOT_OK(op->Init());
  // Sum per-file row counts; take each path by const reference to avoid
  // copying the strings on every iteration.
  for (const auto &data_file : op->FileNames()) {
    *count += op->CountRows(data_file);
  }
  return Status::OK();
}
// Counts the non-empty lines of one data file; returns 0 when the file cannot
// be opened (an error is logged, not raised).
int64_t USPSOp::CountRows(const std::string &data_file) {
  std::ifstream reader(data_file, std::ios::in);
  if (!reader.is_open()) {
    MS_LOG(ERROR) << "Invalid file, failed to open file: " << data_file;
    return 0;
  }
  int64_t num_lines = 0;
  for (std::string row; std::getline(reader, row);) {
    if (!row.empty()) {
      ++num_lines;
    }
  }
  // The stream closes itself when `reader` goes out of scope.
  return num_lines;
}
// Resolves dataset_dir_ and appends the fixed-name USPS files selected by
// usage_ ("usps" for train, "usps.t" for test, both for "all") to
// data_files_list_. Fails if a selected file is missing or is a directory.
Status USPSOp::GetFiles() {
  auto real_dataset_dir = Common::GetRealPath(dataset_dir_);
  CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(), "Get real path failed: " + dataset_dir_);
  Path root_dir(real_dataset_dir.value());
  // usage_ has already been validated upstream, so it is one of these three.
  const bool use_train = (usage_ == "train" || usage_ == "all");
  const bool use_test = (usage_ == "test" || usage_ == "all");
  if (use_train) {
    Path train_path = root_dir / Path("usps");
    CHECK_FAIL_RETURN_UNEXPECTED(train_path.Exists() && !train_path.IsDirectory(),
                                 "Invalid file, failed to find USPS train data file: " + train_path.ToString());
    data_files_list_.emplace_back(train_path.ToString());
    MS_LOG(INFO) << "USPS operator found train data file " << train_path.ToString() << ".";
  }
  if (use_test) {
    Path test_path = root_dir / Path("usps.t");
    CHECK_FAIL_RETURN_UNEXPECTED(test_path.Exists() && !test_path.IsDirectory(),
                                 "Invalid file, failed to find USPS test data file: " + test_path.ToString());
    data_files_list_.emplace_back(test_path.ToString());
    MS_LOG(INFO) << "USPS operator found test data file " << test_path.ToString() << ".";
  }
  return Status::OK();
}
// Reads one USPS text file and pushes the rows whose index falls in
// [start_offset, end_offset) into the jagged connector for `worker_id`.
// Only non-empty lines are counted, which keeps the indexing consistent with
// CountRows() (it skips empty lines too).
Status USPSOp::LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  std::ifstream data_file_reader(data_file);
  if (!data_file_reader.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + data_file);
  }
  int64_t rows_total = 0;
  std::string line;
  while (getline(data_file_reader, line)) {
    // Empty lines are skipped without advancing rows_total.
    if (line.empty()) {
      continue;
    }
    // If read to the end offset of this file, break.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip line before start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }
    TensorRow tRow(1, nullptr);
    tRow.setPath({data_file});
    // Parse the line into the row's image/label tensors.
    Status rc = LoadTensor(&line, &tRow);
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }
    // Hand the row over to this worker's connector; close on failure so no
    // file handle leaks when the pipeline is shutting down.
    rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }
    rows_total++;
  }
  data_file_reader.close();
  return Status::OK();
}
// Parses one text line into a (16x16x1 uint8 image, uint32 scalar label) row.
// @param line - the content of the row (consumed by ParseLine).
// @param trow - output tensor row.
// @return Status - the error code returned.
Status USPSOp::LoadTensor(std::string *line, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(line);
  RETURN_UNEXPECTED_IF_NULL(trow);
  // std::make_unique throws std::bad_alloc on failure, so the buffers can
  // never be null here (the old explicit null check was dead code).
  auto images_buffer = std::make_unique<unsigned char[]>(kUSPSImageSize);
  auto labels_buffer = std::make_unique<uint32_t[]>(1);
  RETURN_IF_NOT_OK(this->ParseLine(line, images_buffer, labels_buffer));
  // Build the image tensor (HWC layout) and the scalar label tensor.
  std::shared_ptr<Tensor> image, label;
  TensorShape image_tensor_shape = TensorShape({kUSPSImageHeight, kUSPSImageWidth, kUSPSImageChannel});
  // images_buffer.get() is already `unsigned char *`; no cast needed.
  RETURN_IF_NOT_OK(
    Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(), images_buffer.get(), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(labels_buffer[0], &label));
  (*trow) = {std::move(image), std::move(label)};
  return Status::OK();
}
// Parses one line of the USPS text format: "<label> 1:<pix> 2:<pix> ... 256:<pix> "
// where <label> is 1~10 and each <pix> is a float in [-1, 1]. The line is
// consumed (erased token by token) as it is parsed.
// @param line - the content of the row.
// @param images_buffer - destination for the kUSPSImageSize uint8 pixels.
// @param labels_buffer - destination for the single 0~9 label.
// @return Status - the error code returned.
Status USPSOp::ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                         const std::unique_ptr<uint32_t[]> &labels_buffer) {
  auto label = &labels_buffer[0];
  auto pixels = &images_buffer[0];
  size_t pos = 0;
  int32_t split_num = 0;
  // Consume one space-terminated token per iteration (char overload of find
  // avoids constructing a temporary search string).
  while ((pos = line->find(' ')) != std::string::npos) {
    split_num += 1;
    std::string item = line->substr(0, pos);
    if (split_num == 1) {
      // The first token is the class label; the file stores 1~10 but we need 0~9.
      // NOTE(review): std::stoi throws on a non-numeric token instead of
      // returning a Status — assumed acceptable, matching existing behavior.
      *label = static_cast<uint32_t>(std::stoi(item)) - 1;
    } else {
      // Remaining tokens are "<index>:<pixel>" pairs with 1-based indices.
      size_t split_pos = item.find(':');
      CHECK_FAIL_RETURN_UNEXPECTED(split_pos != std::string::npos, "Invalid data, USPS data file is corrupted.");
      // check pixel index
      CHECK_FAIL_RETURN_UNEXPECTED(std::stoi(item.substr(0, split_pos)) == (split_num - 1),
                                   "Invalid data, USPS data file is corrupted.");
      // Everything after ':' is the pixel value (the old length argument
      // overshot by one; substr clamps, so behavior is unchanged).
      std::string pixel_str = item.substr(split_pos + 1);
      // transform the real pixel value from [-1, 1] to the integers within [0, 255]
      pixels[split_num - 2] = static_cast<uint8_t>((std::stof(pixel_str) + 1.0) / 2.0 * 255.0);
    }
    line->erase(0, pos + 1);
  }
  // One label token plus kUSPSImageSize pixel tokens must all have been seen.
  CHECK_FAIL_RETURN_UNEXPECTED(split_num == (kUSPSImageSize + 1), "Invalid data, USPS data file is corrupted.");
  return Status::OK();
}
// Counts rows per file, caches them in filename_numrows_, and derives how many
// rows each shard (device) must produce, rounding up so every shard is covered.
// @return Status - the error code returned.
Status USPSOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountRows(it.value());
    filename_numrows_[it.value()] = count;
    num_rows_ += count;
  }
  if (num_rows_ == 0) {
    // Build the file list for the error message. Iterating by const reference
    // also removes the signed/unsigned comparison the old index loop had.
    std::stringstream ss;
    for (const auto &file_name : data_files_list_) {
      ss << " " << file_name;
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED("Invalid data, data file may not be suitable to read with USPSDataset API. Check file: " +
                             file_list);
  }
  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}
// Builds (filename, key) IOBlocks and distributes them round-robin over the
// worker queues until this device's shard (num_rows_per_shard_ rows) is covered.
// @param i_keys - explicit file ordering (used when files are shuffled); empty
//                 means walk the filename index in natural order.
// @return Status - the error code returned.
Status USPSOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      // Shuffled order: resolve each key back to its file name.
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        if (!load_io_block_queue_) {
          break;  // the consumer asked us to stop filling
        }
        file_index.emplace_back((*filename_index_)[*it], *it);
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(it.value(), it.key());
      }
    }
    // Push one IOBlock per file slice that overlaps this shard. Iterate by
    // const reference: the old by-value loop copied a (path, key) pair per file.
    for (const auto &file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto io_block =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(io_block)));
        queue_index = (queue_index + 1) % num_workers_;
      }
      pre_count += filename_numrows_[file_info.first];
    }
    // Re-walk the files until enough rows have been scheduled for this device.
    finish = pre_count >= (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  }
  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}
// Builds the column-name -> index map (a base-class field) from the schema.
// Idempotent: a repeat call only logs a warning.
Status USPSOp::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  for (int32_t col = 0; col < data_schema_->NumColumns(); ++col) {
    column_name_id_map_[data_schema_->Column(col).Name()] = col;
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,137 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore {
namespace dataset {
// USPSOp reads the USPS handwritten-digit text files ("usps" for train,
// "usps.t" for test) and produces rows with two columns: a 16x16x1 uint8
// "image" and a uint32 scalar "label".
class USPSOp : public NonMappableLeafOp {
 public:
  // Constructor.
  // @param const std::string &dataset_dir - dir directory of USPS data file.
  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'.
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the USPS dataset.
  // @param num_workers - number of worker threads reading data from USPS data files.
  // @param worker_connector_size - size of each internal queue.
  // @param num_samples - number of samples to read.
  // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
  // @param shuffle_files - whether to shuffle the files before reading data.
  // @param num_devices - number of devices.
  // @param device_id - device id.
  USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
         int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Destructor.
  ~USPSOp() = default;

  // Op name getter.
  // @return std::string - Name of the current Op.
  std::string Name() const override { return "USPSOp"; }

  // A print method typically used for debugging.
  // @param std::ostream &out - out stream.
  // @param bool show_all - whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  // Instantiates the internal queues and connectors.
  // @return Status - the error code returned.
  Status Init() override;

  // Function to count the number of samples in the USPS dataset.
  // @param const std::string &dir - path to the USPS directory.
  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'.
  // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples.
  // @return Status - the error code returned.
  static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);

  // File names getter.
  // @return Vector of the input file names.
  std::vector<std::string> FileNames() { return data_files_list_; }

 private:
  // Function to count the number of samples (non-empty lines) in one data file.
  // @param const std::string &data_file - path to the data file.
  // @return int64_t - the count result (0 if the file cannot be opened).
  int64_t CountRows(const std::string &data_file);

  // Reads a data file and loads the data into multiple TensorRows.
  // @param data_file - the data file to read.
  // @param start_offset - the start offset of file.
  // @param end_offset - the end offset of file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;

  // Parses a single row and puts the data into a tensor table.
  // @param line - the content of the row.
  // @param trow - image & label read into this tensor row.
  // @return Status - the error code returned.
  Status LoadTensor(std::string *line, TensorRow *trow);

  // Calculate number of rows in each shard.
  // @return Status - the error code returned.
  Status CalculateNumRowsPerShard() override;

  // Fill the IOBlockQueue.
  // @param i_keys - keys of file to fill to the IOBlockQueue.
  // @return Status - the error code returned.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;

  // Get all files in the dataset_dir_.
  // @return Status - The status code returned.
  Status GetFiles();

  // Parse a line to image and label.
  // @param line - the content of the row.
  // @param images_buffer - image destination.
  // @param labels_buffer - label destination.
  // @return Status - the status code returned.
  Status ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                   const std::unique_ptr<uint32_t[]> &labels_buffer);

  // Private function for computing the assignment of the column name map.
  // @return Status - the error code returned.
  Status ComputeColMap() override;

  const std::string usage_;  // can be "all", "train" or "test".
  std::string dataset_dir_;  // directory of data files.
  std::unique_ptr<DataSchema> data_schema_;
  std::vector<std::string> data_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_

View File

@ -91,6 +91,7 @@ constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUSPSNode[] = "USPSDataset";
constexpr char kVOCNode[] = "VOCDataset";
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,

View File

@ -19,6 +19,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
random_node.cc
text_file_node.cc
tf_record_node.cc
usps_node.cc
voc_node.cc
)

View File

@ -0,0 +1,173 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/usps_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor: stores all user-facing parameters as-is; validation is deferred
// to ValidateParams().
USPSNode::USPSNode(std::string dataset_dir, std::string usage, int32_t num_samples, ShuffleMode shuffle,
                   int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
    : NonMappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {
  // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
  // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
  // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
  // PreBuildSampler is phased out, this can be cleaned up.
  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> USPSNode::Copy() {
  // Clone with identical construction arguments; the cache handle is shared.
  return std::make_shared<USPSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
}
void USPSNode::Print(std::ostream &out) const {
  // One-line summary: node name plus the user-visible construction parameters.
  // Streaming the ints directly produces the same bytes as std::to_string did.
  out << Name() << "(dataset dir:" << dataset_dir_ << ", usage:" << usage_ << ", num_shards:" << num_shards_
      << ", shard_id:" << shard_id_ << ", num_samples:" << num_samples_ << ")";
}
// Validates directory existence, the usage string, a non-negative sample count,
// and the num_shards/shard_id combination; returns a syntax error otherwise.
Status USPSNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSNode", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateStringValue("USPSNode", usage_, {"train", "test", "all"}));
  if (num_samples_ < 0) {
    std::string err_msg = "USPSNode: Invalid number of samples: " + std::to_string(num_samples_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSNode", num_shards_, shard_id_));
  return Status::OK();
}
// Builds the runtime USPSOp (and, for global shuffle, an extra ShuffleOp above
// it) and appends them to node_ops.
Status USPSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // File-level shuffling inside USPSOp is enough for kFiles; kGlobal also needs
  // a ShuffleOp injected below.
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  auto op = std::make_shared<USPSOp>(dataset_dir_, usage_, std::move(schema), num_workers_, worker_connector_size_,
                                     num_samples_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
  RETURN_IF_NOT_OK(op->Init());

  // If a global shuffle is used for USPS, it will inject a shuffle op over the USPS.
  // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
  // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset USPS's shuffle
  // option to false.
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows));

    // Add the shuffle op after this op
    RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
    shuffle_op->set_total_repeats(GetTotalRepeats());
    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
    // The shuffle op is pushed first so it ends up above the USPS op in the tree.
    node_ops->push_back(shuffle_op);
  }
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node
Status USPSNode::GetShardId(int32_t *shard_id) {
  // Guard the out-param, consistent with the other out-param APIs in this
  // operator (e.g. USPSOp::CountTotalRows).
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  *shard_id = shard_id_;
  return Status::OK();
}
// Get Dataset size
Status USPSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  // Return the cached value if the size was already computed.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = 0;  // initialized explicitly; the old shared declaration left it indeterminate
  const int64_t sample_size = num_samples_;
  RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
  // Rows are split across shards, rounding up so every shard is covered.
  num_rows = static_cast<int64_t>(std::ceil(num_rows / (1.0 * num_shards_)));
  // A positive num_samples_ caps the per-shard size.
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
Status USPSNode::to_json(nlohmann::json *out_json) {
  // Serialize every construction parameter (plus the optional cache) so the
  // node can be reconstructed from JSON.
  nlohmann::json serialized;
  serialized["num_parallel_workers"] = num_workers_;
  serialized["dataset_dir"] = dataset_dir_;
  serialized["usage"] = usage_;
  serialized["num_samples"] = num_samples_;
  serialized["shuffle"] = shuffle_;
  serialized["num_shards"] = num_shards_;
  serialized["shard_id"] = shard_id_;
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    serialized["cache"] = cache_args;
  }
  *out_json = serialized;
  return Status::OK();
}
// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
// USPS itself is non-mappable and does not sample, but a cache injected higher
// in the tree can inherit this sampler from the leaf and provide sampling from
// the caching layer — so we still set one up here.
Status USPSNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  const bool file_shuffle = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  *sampler = SelectSampler(num_samples_, file_shuffle, num_shards_, shard_id_);
  return Status::OK();
}
// If a cache sits above this node, the cache runs the sampler itself, so reset
// every sharding/shuffling/sampling option to its default: this node must then
// produce the complete dataset into the cache.
// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
Status USPSNode::MakeSimpleProducer() {
  num_samples_ = 0;
  shuffle_ = ShuffleMode::kFalse;
  num_shards_ = 1;
  shard_id_ = 0;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,122 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class USPSNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
USPSNode(std::string dataset_dir, std::string usage, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~USPSNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kUSPSNode; }
/// \brief Print the description.
/// \param out - The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \brief Base-class override that computes the number of rows in this dataset.
  /// \param[in] size_getter Shared pointer to the DatasetSizeGetter helper used to compute the size.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size The size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter for the dataset root directory.
  const std::string &DatasetDir() const { return dataset_dir_; }

  /// \brief Getter for the usage string ("train", "test" or "all").
  const std::string &Usage() const { return usage_; }

  /// \brief Getter for the requested number of samples (0 means all samples).
  int32_t NumSamples() const { return num_samples_; }

  /// \brief Getter for the number of shards the dataset is divided into.
  int32_t NumShards() const { return num_shards_; }

  /// \brief Getter for this shard's id within num_shards.
  int32_t ShardId() const { return shard_id_; }

  /// \brief Getter for the per-epoch shuffle mode.
  ShuffleMode Shuffle() const { return shuffle_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

  /// \brief USPS by itself is a non-mappable dataset that does not support sampling.
  ///     However, if a cache operator is injected at some other place higher in the tree, that cache can
  ///     inherit this sampler from the leaf, providing sampling support from the caching layer.
  ///     That is why we setup the sampler for a leaf node that does not use sampling.
  /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
  /// \param[in] sampler The sampler to setup.
  /// \return Status of the function.
  Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;

  /// \brief If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing
  ///     a sampler for fetching the data. As such, any options in the USPS node need to be reset to its defaults
  ///     so that this USPS node will produce the full set of data into the cache.
  /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
  /// \return Status of the function.
  Status MakeSimpleProducer() override;
 private:
  std::string dataset_dir_;  // root directory containing the USPS data files
  std::string usage_;        // "train", "test" or "all"
  // NOTE(review): stored as int32_t while the public USPS()/USPSDataset API takes int64_t
  // num_samples — confirm the narrowing is intended.
  int32_t num_samples_;      // number of samples to read (0 = all)
  ShuffleMode shuffle_;      // per-epoch shuffle mode
  int32_t num_shards_;       // number of shards the dataset is divided into
  int32_t shard_id_;         // this shard's id within num_shards_
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_

View File

@ -2386,6 +2386,56 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase
return ds;
}
/// \class USPSDataset
/// \brief A source dataset that reads and parses USPS datasets.
class USPSDataset : public Dataset {
 public:
  /// \brief Constructor of USPSDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
  /// \param[in] num_samples The number of samples to be included in the dataset
  ///     (Default = 0 means all samples).
  /// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
  ///     Can be any of:
  ///     ShuffleMode.kFalse - No shuffling is performed.
  ///     ShuffleMode.kFiles - Shuffle files only.
  ///     ShuffleMode.kGlobal - Shuffle both the files and samples.
  /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
  /// \param[in] shard_id The shard ID within num_shards. This argument should be
  ///     specified only when num_shards is also specified (Default = 0).
  /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  explicit USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                       ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                       const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor of USPSDataset.
  ~USPSDataset() = default;
};
/// \brief Function to create a USPSDataset.
/// \notes The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset
///     (Default = 0 means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
///     Can be any of:
///     ShuffleMode.kFalse - No shuffling is performed.
///     ShuffleMode.kFiles - Shuffle files only.
///     ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
///     specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \return Shared pointer to the created USPSDataset.
inline std::shared_ptr<USPSDataset> USPS(const std::string &dataset_dir, const std::string &usage = "all",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // The ABI boundary passes strings as std::vector<char>; convert once up front.
  auto dir_chars = StringToChar(dataset_dir);
  auto usage_chars = StringToChar(usage);
  return std::make_shared<USPSDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \class VOCDataset
/// \brief A source dataset for reading and parsing VOC dataset.
class VOCDataset : public Dataset {

View File

@ -46,6 +46,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class RandomDataDataset;
friend class TextFileDataset;
friend class TFRecordDataset;
friend class USPSDataset;
friend class VOCDataset;
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

View File

@ -64,7 +64,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -4712,6 +4712,103 @@ class Schema:
return schema_obj.cpp_schema.get_num_rows()
class USPSDataset(SourceDataset):
    """
    A source dataset for reading and parsing the USPS dataset.

    The generated dataset has two columns: :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar of the uint32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 7,291
            train samples, "test" will read from 2,007 test samples, "all" will read from all 9,298 samples.
            (default=None, will read all samples)
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, will read all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, will use value set in the config).
        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
            (default=Shuffle.GLOBAL).
            If shuffle is False, no shuffling will be performed;
            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
            Otherwise, there are two levels of shuffling:

            - Shuffle.GLOBAL: Shuffle both the files and samples.

            - Shuffle.FILES: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the max sample number of per shard.
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If dataset_dir is not valid or does not exist or does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If usage is invalid.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Examples:
        >>> usps_dataset_dir = "/path/to/usps_dataset_directory"
        >>>
        >>> # Read 3 samples from USPS dataset
        >>> dataset = ds.USPSDataset(dataset_dir=usps_dataset_dir, num_samples=3)
        >>>
        >>> # Note: In USPS dataset, each dictionary has keys "image" and "label"

    About USPS dataset:

    USPS is a digit dataset automatically scanned from envelopes by the U.S. Postal Service
    containing a total of 9,298 16x16 pixel grayscale samples.
    The images are centered, normalized and show a broad range of font styles.

    Here is the original USPS dataset structure.
    You can download and unzip the dataset files into this directory structure and read by MindSpore's API.

    .. code-block::

        .
        usps_dataset_dir
            usps
            usps.t

    Citation:

    .. code-block::

        @article{hull1994database,
          title={A database for handwritten text recognition research},
          author={Hull, Jonathan J.},
          journal={IEEE Transactions on pattern analysis and machine intelligence},
          volume={16},
          number={5},
          pages={550--554},
          year={1994},
          publisher={IEEE}
        }
    """

    @check_usps_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
                 num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir
        # usage=None means "all": read both the train and the test files.
        self.usage = replace_none(usage, "all")

    def parse(self, children=None):
        # Build the C++ USPSNode IR node from this dataset's configuration.
        return cde.USPSNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
                            self.shard_id)
class VOCDataset(MappableDataset):
"""
A source dataset for reading and parsing VOC dataset.

View File

@ -151,6 +151,33 @@ def check_tfrecorddataset(method):
return new_method
def check_usps_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # The dataset directory must exist before anything else is validated.
        check_dir(param_dict.get('dataset_dir'))

        # usage may be None (meaning "all"); when provided it must be a known split name.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        # These optional arguments must all be integers when provided.
        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
def check_vocdataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

View File

@ -31,6 +31,7 @@ SET(DE_UT_SRCS
c_api_dataset_save.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_usps_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_epoch_ctrl_test.cc

View File

@ -0,0 +1,255 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
// Test fixture for USPS dataset pipeline tests; inherits the shared dataset test utilities.
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Feature: USPSDataset
/// Description: iterate the "train" split of the test USPS data and count the rows.
/// Expectation: 3 rows are produced, each carrying "image" and "label" columns.
TEST_F(MindDataTestPipeline, TestUSPSTrainDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDataset.";

  // Build the USPS "train" dataset from the shared test-data directory.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
  EXPECT_NE(ds, nullptr);

  // Creating the iterator builds the execution tree and launches it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t row_count = 0;
  while (!row.empty()) {
    ++row_count;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(row_count, 3);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: USPSDataset
/// Description: iterate the "test" split of the test USPS data and count the rows.
/// Expectation: 3 rows are produced, each carrying "image" and "label" columns.
TEST_F(MindDataTestPipeline, TestUSPSTestDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTestDataset.";

  // Build the USPS "test" dataset from the shared test-data directory.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "test");
  EXPECT_NE(ds, nullptr);

  // Creating the iterator builds the execution tree and launches it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t row_count = 0;
  while (!row.empty()) {
    ++row_count;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(row_count, 3);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: USPSDataset
/// Description: iterate the "all" split (train + test files) and count the rows.
/// Expectation: 6 rows are produced, each carrying "image" and "label" columns.
TEST_F(MindDataTestPipeline, TestUSPSAllDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSAllDataset.";

  // Build the USPS "all" dataset from the shared test-data directory.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "all");
  EXPECT_NE(ds, nullptr);

  // Creating the iterator builds the execution tree and launches it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t row_count = 0;
  while (!row.empty()) {
    ++row_count;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // 3 train rows + 3 test rows.
  EXPECT_EQ(row_count, 6);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: USPSDataset
/// Description: concatenate two USPS "train" pipelines with Repeat and Project applied.
/// Expectation: the concatenated pipeline yields 3 + 3 = 6 rows.
TEST_F(MindDataTestPipeline, TestUSPSDatasetWithPipeline) {
  // Fix: log tag previously said "TestUSPSTrainDatasetWithPipeline", mismatching the test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetWithPipeline.";

  // Create two USPS Train Datasets.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds1 = USPS(folder_path, "train");
  std::shared_ptr<Dataset> ds2 = USPS(folder_path, "train");
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create a Repeat operation on each dataset (repeat count 1 keeps row counts unchanged).
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create a Project operation on each dataset, selecting both columns.
  std::vector<std::string> column_project = {"image", "label"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Concatenate the two pipelines.
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Creating the iterator builds the execution tree and launches it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // Two concatenated copies of the 3-row train split.
  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: USPSDataset
/// Description: query GetDatasetSize on the "train" split of the test data.
/// Expectation: the reported size equals the 3 rows in the test data file.
TEST_F(MindDataTestPipeline, TestGetUSPSDatasetSize) {
  // Fix: log tag previously said "TestGetUSPSTrainDatasetSize", mismatching the test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetUSPSDatasetSize.";

  // Create a USPS Train Dataset.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 3);
}
/// Feature: USPSDataset
/// Description: exercise the dataset getter functions on the "train" split.
/// Expectation: size, output types/shapes, column names, batch size and repeat count all match.
TEST_F(MindDataTestPipeline, TestUSPSDatasetGetters) {
  // Fix: log tag previously said "TestUSPSTrainDatasetGetters", mismatching the test name.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetGetters.";

  // Create a USPS Train Dataset.
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetDatasetSize(), 3);

  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "uint32");
  EXPECT_EQ(shapes.size(), 2);
  EXPECT_EQ(shapes[0].ToString(), "<16,16,1>");
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // The getters are deliberately queried more than once: repeated calls must
  // return consistent results (presumably served from cached metadata — the
  // caching itself is implemented elsewhere).
  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);

  EXPECT_EQ(ds->GetColumnNames(), column_names);

  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetDatasetSize(), 3);
}
/// Feature: USPSDataset
/// Description: construct a USPS dataset with an empty dataset directory.
/// Expectation: iterator creation fails and returns nullptr.
TEST_F(MindDataTestPipeline, TestUSPSDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetFail.";

  // An empty dataset directory is invalid input for USPS.
  std::shared_ptr<Dataset> ds = USPS("", "train");
  EXPECT_NE(ds, nullptr);

  // Validation happens when the iterator (execution tree) is created,
  // so the failure surfaces here rather than at construction time.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_EQ(iter, nullptr);
}
/// Feature: USPSDataset
/// Description: construct a USPS dataset with an unsupported usage string.
/// Expectation: iterator creation fails and returns nullptr.
TEST_F(MindDataTestPipeline, TestUSPSDatasetWithInvalidUsageFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetWithInvalidUsageFail.";

  // "validation" is not one of the supported usages ("train", "test", "all").
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "validation");
  EXPECT_NE(ds, nullptr);

  // Validation happens when the iterator (execution tree) is created,
  // so the failure surfaces here rather than at construction time.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_EQ(iter, nullptr);
}

View File

@ -0,0 +1,3 @@
1 1:0.600000 2:-0.81176 3:0.937254 4:0.160784 5:-0.25490 6:0.615686 7:-0.02745 8:0.866666 9:-0.03529 10:0.137254 11:-0.52941 12:-0.20784 13:0.411764 14:-0.71764 15:-0.59215 16:-0.08235 17:-0.6 18:0.505882 19:-0.09803 20:-0.22352 21:-0.30980 22:0.725490 23:0.294117 24:-0.70196 25:0.929411 26:-0.30196 27:0.247058 28:0.035294 29:-0.46666 30:-0.86666 31:-0.35686 32:-0.54509 33:0.992156 34:-0.04313 35:-0.84313 36:-0.39607 37:-0.16078 38:-0.72549 39:-1.0 40:-0.10588 41:0.090196 42:-0.29411 43:-0.42745 44:-0.63137 45:-0.03529 46:-0.44313 47:-0.16078 48:-0.62352 49:0.254901 50:0.168627 51:0.545098 52:-0.37254 53:-0.89019 54:-0.63137 55:-0.18431 56:0.631372 57:0.749019 58:-0.78039 59:0.058823 60:-0.45882 61:-0.47450 62:-0.21568 63:0.890196 64:-0.69411 65:0.976470 66:-0.81176 67:0.419607 68:-0.95294 69:-0.56078 70:-0.11372 71:0.145098 72:0.105882 73:0.537254 74:-0.73333 75:-0.30196 76:-0.56862 77:0.811764 78:0.552941 79:0.152941 80:-0.88235 81:-0.69411 82:0.647058 83:-0.11372 84:-0.99215 85:-0.85098 86:-0.43529 87:0.278431 88:-0.83529 89:0.427450 90:-0.30196 91:0.725490 92:0.929411 93:0.239215 94:0.176470 95:0.709803 96:-0.00392 97:0.317647 98:-0.17647 99:0.098039 100:-0.78039 101:-0.21568 102:0.976470 103:0.670588 104:0.349019 105:0.623529 106:0.631372 107:0.529411 108:-0.53725 109:0.694117 110:0.505882 111:-0.70196 112:0.592156 113:0.458823 114:-0.67058 115:-0.37254 116:0.450980 117:-0.67058 118:0.482352 119:-0.62352 120:0.741176 121:0.552941 122:0.105882 123:-0.12941 124:0.278431 125:-0.36470 126:0.827450 127:0.537254 128:0.827450 129:-0.49803 130:0.341176 131:-0.74901 132:-0.85098 133:0.247058 134:-0.75686 135:0.678431 136:0.670588 137:-1.0 138:-0.23137 139:-0.30196 140:0.223529 141:0.325490 142:-0.63137 143:-0.99215 144:-0.45882 145:0.474509 146:0.341176 147:0.552941 148:-0.11372 149:-0.34117 150:-0.86666 151:-0.56862 152:0.168627 153:-0.11372 154:0.098039 155:-0.94509 156:-0.01960 157:-0.45098 158:0.168627 159:-0.94509 160:-0.84313 161:-0.44313 162:-0.05882 
163:-0.05882 164:-0.49803 165:0.105882 166:-0.96862 167:0.725490 168:-0.30196 169:-0.21568 170:0.803921 171:0.647058 172:-0.46666 173:-0.52941 174:0.796078 175:0.654901 176:-0.81176 177:-0.30196 178:0.811764 179:-0.63137 180:-0.41176 181:-0.12156 182:-1.0 183:-0.59215 184:0.631372 185:-0.41960 186:0.937254 187:0.396078 188:0.137254 189:-0.50588 190:0.443137 191:0.474509 192:0.121568 193:0.074509 194:-0.06666 195:-0.63921 196:0.286274 197:0.019607 198:-0.14509 199:-0.56862 200:0.537254 201:0.333333 202:0.027450 203:-0.95294 204:0.678431 205:0.474509 206:0.027450 207:0.349019 208:-0.91372 209:0.772549 210:0.223529 211:-0.96078 212:-0.38823 213:-0.50588 214:0.764705 215:0.490196 216:0.458823 217:0.945098 218:-0.17647 219:0.560784 220:-0.01960 221:0.474509 222:-0.84313 223:0.043137 224:0.615686 225:0.388235 226:0.082352 227:-0.45098 228:-0.97647 229:-0.58431 230:0.905882 231:0.372549 232:-0.23137 233:-0.07450 234:0.262745 235:-0.70980 236:0.733333 237:-0.51372 238:-0.63921 239:0.443137 240:0.678431 241:-0.64705 242:0.631372 243:-0.98431 244:0.411764 245:0.019607 246:0.239215 247:0.380392 248:-0.41960 249:0.678431 250:-0.81960 251:0.364705 252:-0.93725 253:-0.05882 254:0.623529 255:0.160784 256:-0.19999
7 1:0.294117 2:0.576470 3:-0.98431 4:-0.19999 5:-0.90588 6:0.074509 7:-0.54509 8:-0.98431 9:-0.67843 10:-0.04313 11:-0.56862 12:0.733333 13:-0.85098 14:0.082352 15:-0.50588 16:-0.37254 17:0.160784 18:0.772549 19:-0.69411 20:-0.07450 21:-0.90588 22:0.443137 23:-0.25490 24:0.678431 25:-0.38039 26:-0.87450 27:0.286274 28:-0.67058 29:0.623529 30:0.639215 31:-0.86666 32:0.545098 33:-0.22352 34:-0.05098 35:-0.81176 36:-0.15294 37:0.380392 38:0.058823 39:-0.26274 40:-0.85882 41:0.945098 42:0.945098 43:0.231372 44:0.356862 45:0.874509 46:0.654901 47:0.647058 48:-0.80392 49:0.396078 50:0.129411 51:-0.99215 52:0.945098 53:-0.01176 54:0.199999 55:-0.34901 56:-0.58431 57:0.301960 58:-0.68627 59:0.639215 60:0.176470 61:0.011764 62:0.796078 63:0.513725 64:-0.30980 65:0.615686 66:0.356862 67:0.545098 68:-0.85882 69:0.199999 70:-0.65490 71:-0.15294 72:-0.09803 73:0.090196 74:-0.11372 75:-0.55294 76:0.294117 77:-0.38823 78:0.184313 79:0.458823 80:0.247058 81:0.945098 82:0.976470 83:-0.16078 84:0.380392 85:-0.51372 86:0.984313 87:0.945098 88:0.215686 89:0.207843 90:0.301960 91:-0.67843 92:-0.34117 93:0.741176 94:0.372549 95:-0.70196 96:0.066666 97:0.890196 98:0.600000 99:-0.15294 100:-0.57647 101:0.678431 102:-0.25490 103:0.231372 104:-0.10588 105:0.788235 106:0.796078 107:0.301960 108:0.121568 109:0.356862 110:-0.12941 111:0.231372 112:0.066666 113:0.709803 114:0.137254 115:0.905882 116:-0.82745 117:0.443137 118:0.929411 119:-0.05098 120:0.537254 121:-0.53725 122:-0.34901 123:0.850980 124:0.576470 125:0.670588 126:0.380392 127:-0.09803 128:0.952941 129:0.372549 130:-0.96862 131:0.882352 132:0.231372 133:-0.59215 134:0.341176 135:0.694117 136:0.192156 137:-0.03529 138:-0.60784 139:0.450980 140:0.372549 141:-0.10588 142:-0.44313 143:0.741176 144:-0.35686 145:-0.41176 146:0.411764 147:-0.05098 148:-0.57647 149:-0.37254 150:-0.36470 151:0.082352 152:-0.05882 153:-0.29411 154:0.788235 155:0.576470 156:-0.79607 157:0.254901 158:-0.09019 159:0.427450 160:0.301960 161:-0.10588 162:0.952941 
163:0.466666 164:-0.59215 165:-0.95294 166:0.435294 167:-0.12156 168:-0.70980 169:0.262745 170:0.960784 171:-0.78039 172:-0.63921 173:-1.0 174:-0.53725 175:0.678431 176:-0.41176 177:-0.94509 178:0.129411 179:0.372549 180:0.521568 181:-0.75686 182:0.615686 183:-0.04313 184:0.262745 185:-0.37254 186:0.145098 187:0.537254 188:-0.12156 189:-0.01176 190:0.560784 191:0.011764 192:0.160784 193:0.647058 194:-0.57647 195:0.835294 196:0.905882 197:0.278431 198:0.811764 199:-0.86666 200:-0.78823 201:-0.44313 202:0.600000 203:0.505882 204:0.254901 205:0.215686 206:-0.33333 207:-0.65490 208:-0.33333 209:0.317647 210:-0.00392 211:-0.60784 212:-0.63921 213:0.890196 214:-0.64705 215:-0.56078 216:-0.59215 217:0.654901 218:0.168627 219:0.537254 220:-0.76470 221:0.301960 222:-0.33333 223:-0.74117 224:0.098039 225:0.152941 226:0.694117 227:-0.18431 228:0.843137 229:0.592156 230:-0.38039 231:-0.30980 232:0.160784 233:0.262745 234:-0.96862 235:-0.41960 236:0.380392 237:0.623529 238:0.858823 239:-0.24705 240:-0.38823 241:0.356862 242:0.576470 243:-0.38823 244:-0.60784 245:-0.24705 246:-0.54509 247:0.317647 248:0.341176 249:0.913725 250:0.003921 251:-0.34117 252:0.513725 253:-0.96078 254:-0.84313 255:-0.93725 256:-0.92941
1 1:0.411764 2:0.929411 3:-0.03529 4:0.231372 5:-0.30980 6:0.286274 7:-0.58431 8:0.937254 9:0.388235 10:-0.89803 11:-0.16862 12:0.756862 13:-0.67843 14:-0.89019 15:0.749019 16:0.074509 17:-0.25490 18:0.741176 19:-0.45098 20:-0.22352 21:0.788235 22:0.576470 23:-0.59215 24:-0.51372 25:0.560784 26:-0.97647 27:0.670588 28:-0.27843 29:-0.64705 30:0.984313 31:-0.22352 32:0.027450 33:-0.96862 34:-0.78039 35:0.694117 36:0.349019 37:-0.53725 38:-0.37254 39:-0.49019 40:-0.68627 41:0.325490 42:0.992156 43:0.098039 44:0.270588 45:0.882352 46:0.043137 47:-0.39607 48:-0.43529 49:-0.75686 50:-0.21568 51:-0.86666 52:-0.67058 53:-0.62352 54:0.607843 55:0.552941 56:0.380392 57:0.568627 58:-0.31764 59:0.639215 60:0.607843 61:-0.37254 62:0.176470 63:0.098039 64:-0.53725 65:-0.71764 66:-0.30196 67:-0.44313 68:0.568627 69:-0.85098 70:0.505882 71:0.537254 72:-0.41960 73:0.474509 74:0.254901 75:-0.16862 76:-0.44313 77:-0.73333 78:-0.08235 79:0.537254 80:-0.58431 81:-0.96862 82:0.733333 83:-0.01176 84:-0.87450 85:0.701960 86:0.890196 87:-0.61568 88:0.694117 89:-0.83529 90:-0.34117 91:-0.59215 92:0.952941 93:0.239215 94:-0.87450 95:0.803921 96:0.905882 97:0.545098 98:-0.32549 99:-0.38823 100:1.0 101:-0.50588 102:0.835294 103:0.254901 104:-0.70196 105:-0.77254 106:-0.55294 107:0.474509 108:-0.78039 109:-0.88235 110:0.388235 111:-0.52941 112:0.443137 113:0.129411 114:-0.89019 115:0.145098 116:-0.16862 117:0.568627 118:-0.27843 119:-0.23137 120:-0.05882 121:0.003921 122:0.662745 123:-0.87450 124:0.474509 125:0.639215 126:-0.97647 127:0.701960 128:0.247058 129:0.545098 130:-0.70980 131:0.913725 132:-0.63921 133:-0.66274 134:-0.16862 135:-0.63921 136:-0.05098 137:-0.31764 138:-0.54509 139:-0.83529 140:-0.83529 141:0.215686 142:-0.28627 143:-0.39607 144:-0.06666 145:-0.81176 146:0.145098 147:-0.12941 148:-0.44313 149:-0.10588 150:0.199999 151:0.537254 152:-0.36470 153:-0.48235 154:-0.77254 155:-0.96862 156:-0.67843 157:0.850980 158:-0.45882 159:-0.99215 160:0.231372 161:-0.54509 162:-0.77254 
163:0.937254 164:-0.12941 165:-0.32549 166:-0.01176 167:-0.40392 168:0.607843 169:0.835294 170:-0.83529 171:0.082352 172:0.450980 173:-0.85098 174:0.207843 175:0.239215 176:0.827450 177:0.490196 178:-0.68627 179:-0.20784 180:-0.46666 181:0.474509 182:-0.97647 183:-0.30980 184:0.145098 185:0.717647 186:0.411764 187:-0.67843 188:-0.40392 189:-0.90588 190:-0.87450 191:0.113725 192:0.333333 193:0.952941 194:-0.58431 195:-0.83529 196:-0.32549 197:0.882352 198:0.301960 199:-0.82745 200:-0.53725 201:0.952941 202:0.600000 203:0.921568 204:-0.40392 205:0.945098 206:-0.31764 207:0.121568 208:-0.36470 209:-0.19999 210:-0.01176 211:0.011764 212:-0.31764 213:-0.45098 214:-0.94509 215:-0.05098 216:-0.42745 217:0.749019 218:-0.64705 219:-0.81176 220:0.505882 221:-0.66274 222:0.623529 223:0.882352 224:-0.18431 225:0.835294 226:0.968627 227:0.050980 228:0.160784 229:-0.41960 230:0.662745 231:-0.26274 232:-0.64705 233:-0.63137 234:-0.89019 235:0.694117 236:-0.56862 237:0.819607 238:0.686274 239:-0.18431 240:-0.57647 241:-0.6 242:0.976470 243:0.160784 244:-0.68627 245:0.223529 246:0.615686 247:0.074509 248:-0.64705 249:0.717647 250:0.921568 251:0.058823 252:0.349019 253:-0.31764 254:-0.13725 255:-0.49019 256:-0.57647

View File

@ -0,0 +1,3 @@
8 1:-0.41176 2:-0.22352 3:0.129411 4:0.380392 5:0.545098 6:-0.44313 7:-0.78039 8:0.725490 9:-0.36470 10:0.529411 11:0.309803 12:-0.34117 13:-0.27058 14:0.521568 15:0.780392 16:0.890196 17:-0.76470 18:0.356862 19:-0.09019 20:-0.49019 21:0.694117 22:-0.09019 23:0.301960 24:-0.44313 25:-0.87450 26:0.709803 27:0.749019 28:-0.89019 29:0.223529 30:-0.89019 31:0.490196 32:-0.85098 33:0.694117 34:0.090196 35:0.035294 36:0.788235 37:0.333333 38:0.419607 39:0.678431 40:-0.74117 41:0.137254 42:0.756862 43:0.576470 44:-0.97647 45:-0.38823 46:-0.36470 47:-0.67843 48:0.678431 49:-0.58431 50:-0.58431 51:0.309803 52:-0.52941 53:-0.67843 54:0.403921 55:0.858823 56:-0.25490 57:0.717647 58:-0.47450 59:-0.92941 60:0.529411 61:-0.37254 62:-0.89019 63:-0.51372 64:-0.91372 65:0.811764 66:-0.78039 67:1.0 68:0.600000 69:-0.20784 70:0.239215 71:0.027450 72:0.388235 73:-0.38823 74:-0.81960 75:-0.24705 76:-0.25490 77:-0.09803 78:0.568627 79:-0.40392 80:0.882352 81:-0.44313 82:0.199999 83:0.568627 84:-0.88235 85:0.066666 86:0.035294 87:0.490196 88:0.701960 89:-0.42745 90:0.560784 91:-1.0 92:-0.96078 93:0.631372 94:0.552941 95:0.960784 96:-0.92156 97:0.913725 98:-0.04313 99:0.843137 100:0.741176 101:0.364705 102:0.654901 103:0.529411 104:0.709803 105:-0.69411 106:0.670588 107:0.976470 108:-0.28627 109:0.482352 110:0.529411 111:0.670588 112:-0.46666 113:0.443137 114:-0.52156 115:0.168627 116:-0.05882 117:0.145098 118:-0.05098 119:0.811764 120:-0.91372 121:0.113725 122:0.725490 123:0.427450 124:-0.00392 125:0.725490 126:0.505882 127:0.333333 128:-0.22352 129:0.490196 130:-0.78039 131:0.458823 132:-0.11372 133:0.137254 134:0.615686 135:-0.54509 136:-0.81176 137:-0.23921 138:-0.62352 139:-0.87450 140:0.560784 141:-0.92941 142:0.858823 143:0.364705 144:0.035294 145:0.176470 146:0.921568 147:0.552941 148:-0.60784 149:0.396078 150:0.168627 151:0.254901 152:-0.40392 153:-0.27058 154:0.396078 155:-0.29411 156:0.011764 157:0.003921 158:-0.75686 159:-0.10588 160:0.741176 161:-0.06666 162:-0.97647 
163:-0.78039 164:0.356862 165:0.341176 166:-0.05882 167:0.576470 168:-0.6 169:0.647058 170:-0.80392 171:-0.12941 172:0.513725 173:0.701960 174:0.937254 175:-0.49019 176:-0.70980 177:0.960784 178:0.474509 179:-0.17647 180:-0.37254 181:-0.23137 182:-0.06666 183:-0.37254 184:-0.76470 185:0.474509 186:0.168627 187:-0.71764 188:-0.48235 189:0.505882 190:-0.13725 191:0.623529 192:0.333333 193:-0.43529 194:-0.34901 195:-0.19215 196:0.717647 197:0.513725 198:-0.25490 199:-0.30196 200:0.098039 201:0.937254 202:-0.63921 203:-0.37254 204:-0.99215 205:0.301960 206:0.482352 207:-0.21568 208:0.050980 209:0.419607 210:0.756862 211:0.701960 212:0.294117 213:-0.45882 214:0.341176 215:0.992156 216:0.003921 217:0.411764 218:-0.33333 219:0.427450 220:-0.97647 221:0.545098 222:0.341176 223:-0.90588 224:0.811764 225:0.176470 226:-0.78823 227:0.050980 228:0.733333 229:0.286274 230:-0.78039 231:-0.89803 232:0.490196 233:0.788235 234:-0.52156 235:-0.78039 236:0.035294 237:0.600000 238:-0.12156 239:-0.09803 240:-0.35686 241:-0.82745 242:-0.05098 243:-0.67058 244:-0.74901 245:-0.30980 246:-0.94509 247:0.419607 248:-0.97647 249:-0.52941 250:-0.56078 251:-0.96078 252:-0.83529 253:0.905882 254:-0.32549 255:-0.68627 256:0.215686
7 1:-0.33333 2:0.380392 3:0.866666 4:-0.58431 5:0.403921 6:0.145098 7:0.074509 8:0.027450 9:-0.31764 10:0.239215 11:0.333333 12:-0.38823 13:0.662745 14:-0.56862 15:-0.27058 16:0.505882 17:-0.79607 18:0.898039 19:-0.45882 20:-0.98431 21:0.152941 22:-0.03529 23:0.270588 24:0.866666 25:-0.95294 26:0.803921 27:0.694117 28:0.450980 29:0.733333 30:-0.15294 31:0.286274 32:-0.95294 33:-0.93725 34:0.654901 35:-0.25490 36:-0.81176 37:0.678431 38:0.545098 39:0.035294 40:-0.05098 41:-0.46666 42:-0.05882 43:0.764705 44:0.545098 45:0.866666 46:0.631372 47:0.356862 48:-1.0 49:-0.92941 50:-0.44313 51:-0.30196 52:-0.02745 53:0.356862 54:-0.75686 55:0.741176 56:-0.54509 57:-0.73333 58:-0.52941 59:-0.70980 60:-0.85882 61:0.929411 62:0.223529 63:0.654901 64:-0.45098 65:0.552941 66:-0.03529 67:0.498039 68:-0.04313 69:0.600000 70:0.490196 71:0.403921 72:1.0 73:0.741176 74:-0.86666 75:-0.52941 76:-0.17647 77:-0.05098 78:-0.96862 79:0.105882 80:-0.34117 81:0.403921 82:-0.56078 83:0.796078 84:-0.05882 85:0.521568 86:-0.41960 87:0.262745 88:-0.49803 89:0.035294 90:-0.09803 91:0.356862 92:-0.92156 93:-0.70196 94:-0.59215 95:-0.17647 96:-0.96862 97:0.349019 98:-0.23137 99:0.984313 100:0.207843 101:0.011764 102:-0.19999 103:-0.83529 104:0.121568 105:0.058823 106:-0.19215 107:0.670588 108:-0.34117 109:0.890196 110:0.921568 111:-0.65490 112:0.772549 113:0.741176 114:-0.29411 115:-0.64705 116:-0.49803 117:0.058823 118:-0.6 119:0.035294 120:-0.53725 121:-0.97647 122:0.811764 123:0.458823 124:0.945098 125:0.262745 126:0.152941 127:-0.06666 128:0.788235 129:-0.89803 130:-0.26274 131:0.466666 132:0.043137 133:0.607843 134:0.984313 135:-0.22352 136:0.874509 137:0.223529 138:0.835294 139:-0.75686 140:0.858823 141:-0.72549 142:0.254901 143:0.262745 144:0.105882 145:0.192156 146:0.772549 147:-0.19999 148:-0.28627 149:-0.40392 150:-0.96862 151:-0.47450 152:-0.34117 153:0.662745 154:0.615686 155:-0.46666 156:-0.26274 157:0.529411 158:0.741176 159:-0.87450 160:-0.41960 161:0.921568 162:-0.28627 163:0.247058 
164:-0.12941 165:0.513725 166:-0.84313 167:-0.78039 168:0.301960 169:0.121568 170:0.286274 171:0.129411 172:0.623529 173:0.976470 174:-0.39607 175:0.513725 176:-0.63921 177:0.717647 178:0.388235 179:0.443137 180:0.184313 181:-0.49019 182:-0.95294 183:-0.53725 184:-0.29411 185:0.992156 186:0.003921 187:0.223529 188:-0.44313 189:0.443137 190:-0.69411 191:-0.93725 192:-0.54509 193:-0.96078 194:-0.52156 195:-0.06666 196:-0.80392 197:0.913725 198:0.380392 199:-0.35686 200:0.309803 201:0.145098 202:0.913725 203:0.741176 204:0.262745 205:0.952941 206:-0.66274 207:0.827450 208:-0.38823 209:0.105882 210:0.003921 211:0.325490 212:-0.41960 213:0.615686 214:0.521568 215:0.686274 216:-0.48235 217:-0.82745 218:-0.09803 219:-0.19215 220:0.639215 221:-0.09803 222:0.733333 223:-0.35686 224:0.286274 225:0.498039 226:0.152941 227:0.749019 228:0.294117 229:0.505882 230:-0.30196 231:-0.42745 232:0.639215 233:0.137254 234:0.223529 235:-0.6 236:-0.82745 237:0.349019 238:0.388235 239:0.631372 240:-0.41176 241:-0.70980 242:0.090196 243:-0.99215 244:0.270588 245:0.011764 246:0.968627 247:0.850980 248:-0.16862 249:0.686274 250:0.537254 251:-0.37254 252:0.647058 253:0.152941 254:0.717647 255:-0.45098 256:0.623529
6 1:0.443137 2:-0.43529 3:0.223529 4:-0.88235 5:0.694117 6:0.435294 7:-0.41960 8:-0.03529 9:0.552941 10:0.709803 11:-0.58431 12:0.701960 13:-0.23137 14:-0.03529 15:-0.55294 16:1.0 17:0.819607 18:-0.19215 19:0.482352 20:0.058823 21:-0.68627 22:-0.81960 23:0.121568 24:0.709803 25:-0.99215 26:0.474509 27:0.905882 28:-0.06666 29:-0.85098 30:-0.90588 31:-0.67843 32:-0.13725 33:-0.56078 34:-0.36470 35:0.827450 36:0.396078 37:0.388235 38:0.803921 39:0.647058 40:-0.50588 41:-0.84313 42:0.466666 43:-0.18431 44:0.207843 45:0.694117 46:0.427450 47:0.827450 48:0.890196 49:0.129411 50:-0.02745 51:-0.20784 52:0.090196 53:-0.85882 54:-0.00392 55:0.105882 56:-0.79607 57:-0.20784 58:0.874509 59:0.380392 60:-0.55294 61:0.098039 62:0.113725 63:0.121568 64:-0.10588 65:0.301960 66:0.866666 67:0.654901 68:0.929411 69:-0.45098 70:0.984313 71:-0.23921 72:0.027450 73:0.223529 74:-0.74117 75:0.709803 76:-0.63137 77:0.309803 78:-0.13725 79:-0.31764 80:1.0 81:0.647058 82:-1.0 83:-0.03529 84:-0.82745 85:0.937254 86:-0.74901 87:0.945098 88:-0.95294 89:0.168627 90:0.121568 91:0.184313 92:-0.94509 93:0.756862 94:0.631372 95:-0.32549 96:-0.27843 97:0.082352 98:-0.12156 99:0.372549 100:0.725490 101:0.176470 102:-0.63921 103:-0.38039 104:-0.23137 105:0.341176 106:0.050980 107:-0.75686 108:-0.72549 109:0.286274 110:-0.96862 111:0.850980 112:0.223529 113:-0.93725 114:-0.79607 115:0.027450 116:-0.05098 117:-0.51372 118:0.560784 119:0.490196 120:0.223529 121:0.764705 122:-0.90588 123:-0.37254 124:0.403921 125:0.521568 126:-0.33333 127:-0.65490 128:-0.63137 129:-0.00392 130:0.552941 131:0.419607 132:0.733333 133:-0.78823 134:-0.71764 135:-0.52156 136:-0.38823 137:-0.21568 138:-0.69411 139:0.850980 140:-0.73333 141:0.223529 142:-0.54509 143:0.184313 144:0.403921 145:0.207843 146:-0.14509 147:0.764705 148:0.466666 149:-0.51372 150:0.709803 151:0.992156 152:-0.05882 153:-0.49803 154:0.647058 155:-0.59215 156:-0.98431 157:0.890196 158:-0.74117 159:0.662745 160:0.764705 161:0.192156 162:-0.6 163:0.027450 
164:0.294117 165:0.788235 166:-0.39607 167:-0.12941 168:-0.94509 169:0.278431 170:-0.60784 171:0.968627 172:0.199999 173:-0.92941 174:0.984313 175:-0.50588 176:0.262745 177:0.247058 178:-0.46666 179:-0.20784 180:0.199999 181:-0.85882 182:-0.48235 183:0.450980 184:-0.79607 185:-0.41176 186:0.992156 187:-0.92941 188:-0.30980 189:-0.18431 190:0.811764 191:-0.88235 192:0.254901 193:-0.62352 194:0.082352 195:0.019607 196:0.349019 197:0.458823 198:-0.01960 199:-0.85098 200:-0.15294 201:0.160784 202:-0.36470 203:-0.16078 204:0.035294 205:-0.63921 206:0.898039 207:-0.60784 208:0.788235 209:0.803921 210:-0.28627 211:0.662745 212:-0.06666 213:-0.74117 214:-0.81960 215:0.992156 216:-0.89803 217:0.035294 218:-0.43529 219:-0.69411 220:-0.55294 221:0.450980 222:-0.59215 223:0.890196 224:0.780392 225:0.137254 226:0.545098 227:0.890196 228:0.074509 229:0.419607 230:0.372549 231:0.623529 232:-0.81960 233:0.568627 234:0.215686 235:0.552941 236:-0.43529 237:0.356862 238:-0.12941 239:0.929411 240:0.819607 241:-0.37254 242:-0.81960 243:0.976470 244:-0.09019 245:0.145098 246:-0.12941 247:0.780392 248:0.129411 249:-0.83529 250:-0.60784 251:-0.57647 252:-0.20784 253:0.937254 254:0.662745 255:-0.41960 256:-1.0

View File

@ -0,0 +1,308 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test USPS dataset operators
"""
import os
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testUSPSDataset"
WRONG_DIR = "../data/dataset/testMnistData"
def load_usps(path, usage):
    """
    Read a raw USPS data file ("usps" for train, "usps.t" for test) and
    return (images, labels).

    Each file line is "<label> <idx>:<value> ...". Pixel values are floats
    in [-1, 1] and are rescaled to uint8 in [0, 255]; labels in the file
    are 1-based and converted to 0-based ints.
    """
    assert usage in ["train", "test"]
    file_name = "usps" if usage == "train" else "usps.t"
    data_path = os.path.realpath(os.path.join(path, file_name))
    with open(data_path, 'r') as f:
        rows = [line.split() for line in f]
    # Drop the "<idx>:" prefix of every feature field, keep only the value.
    pixel_strings = [[field.split(':')[-1] for field in row[1:]] for row in rows]
    images = np.asarray(pixel_strings, dtype=np.float32).reshape((-1, 16, 16, 1))
    # Rescale [-1, 1] floats to [0, 255] uint8.
    images = ((cast(np.ndarray, images) + 1) / 2 * 255).astype(dtype=np.uint8)
    labels = [int(row[0]) - 1 for row in rows]
    return images, labels
def visualize_dataset(images, labels):
    """
    Plot every sample in a single row of subplots, titling each image
    with its label, then show the figure.
    """
    sample_count = len(images)
    for index, image in enumerate(images):
        plt.subplot(1, sample_count, index + 1)
        plt.imshow(image.squeeze(), cmap=plt.cm.gray)
        plt.title(labels[index])
    plt.show()
def test_usps_content_check():
    """
    Validate USPSDataset image readings against the raw files on disk.
    """
    logger.info("Test USPSDataset Op with content check")

    def check_split(usage, num_samples):
        # Compare every decoded pixel with the reference loader: the op may
        # differ by at most one grey level (rounding) and must never produce
        # an inverted pixel (0 vs 255).
        dataset = ds.USPSDataset(DATA_DIR, usage, num_samples=num_samples, shuffle=False)
        images, labels = load_usps(DATA_DIR, usage)
        count = 0
        # in this example, each dictionary has keys "image" and "label"
        for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
            for m in range(16):
                for n in range(16):
                    actual = data["image"][m, n, 0]
                    expected = images[i][m, n, 0]
                    assert (actual != 0 or expected != 255) and \
                           (actual != 255 or expected != 0)
                    assert (actual == expected) or \
                           (actual == expected + 1) or \
                           (actual + 1 == expected)
            np.testing.assert_array_equal(data["label"], labels[i])
            count += 1
        assert count == 3

    check_split("train", 10)
    check_split("test", 3)
def test_usps_basic():
    """
    Validate USPSDataset
    """
    logger.info("Test USPSDataset Op")

    def count_rows(dataset):
        # Iterate one epoch and return the number of rows produced.
        rows = 0
        for _ in dataset.create_dict_iterator(num_epochs=1):
            rows += 1
        return rows

    # case 1: test loading whole dataset
    assert count_rows(ds.USPSDataset(DATA_DIR, "train")) == 3
    assert count_rows(ds.USPSDataset(DATA_DIR, "test")) == 3

    # case 2: test num_samples
    assert count_rows(ds.USPSDataset(DATA_DIR, "train", num_samples=2)) == 2

    # case 3: test repeat
    repeated = ds.USPSDataset(DATA_DIR, "train", num_samples=2).repeat(5)
    assert count_rows(repeated) == 10

    # case 4: test batch with drop_remainder=False
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
    assert train_data.get_dataset_size() == 3
    assert train_data.get_batch_size() == 1
    train_data = train_data.batch(batch_size=2)  # drop_remainder is default to be False
    assert train_data.get_batch_size() == 2
    assert train_data.get_dataset_size() == 2
    assert count_rows(train_data) == 2

    # case 5: test batch with drop_remainder=True
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
    assert train_data.get_dataset_size() == 3
    assert train_data.get_batch_size() == 1
    train_data = train_data.batch(batch_size=2, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert train_data.get_dataset_size() == 1
    assert train_data.get_batch_size() == 2
    assert count_rows(train_data) == 1
def test_usps_exception():
    """
    Test error cases for USPSDataset.

    Each dataset construction that is expected to raise gets its own
    pytest.raises block: previously both the "train" and "test" calls sat
    inside a single block, so the second call was dead code (the first one
    already raised and control left the block) and the test split was
    never actually exercised.
    """
    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.USPSDataset(DATA_DIR, "train", num_shards=10)
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.USPSDataset(DATA_DIR, "test", num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.USPSDataset(DATA_DIR, "train", shard_id=0)
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.USPSDataset(DATA_DIR, "test", shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id="0")
    with pytest.raises(TypeError, match=error_msg_7):
        ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id="0")

    # Decoding an already-decoded (raw HWC uint8) image must fail.
    error_msg_8 = "invalid input shape"
    with pytest.raises(RuntimeError, match=error_msg_8):
        train_data = ds.USPSDataset(DATA_DIR, "train")
        train_data = train_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        for _ in train_data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        test_data = ds.USPSDataset(DATA_DIR, "test")
        test_data = test_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        for _ in test_data.__iter__():
            pass

    # A directory without USPS files must raise at iteration time.
    error_msg_9 = "failed to find USPS train data file"
    with pytest.raises(RuntimeError, match=error_msg_9):
        train_data = ds.USPSDataset(WRONG_DIR, "train")
        for _ in train_data.__iter__():
            pass
    error_msg_10 = "failed to find USPS test data file"
    with pytest.raises(RuntimeError, match=error_msg_10):
        test_data = ds.USPSDataset(WRONG_DIR, "test")
        for _ in test_data.__iter__():
            pass
def test_usps_visualize(plot=False):
    """
    Visualize USPSDataset results
    """
    logger.info("Test USPSDataset visualization")

    def collect_and_check(usage):
        # Read three samples, validate type/shape/dtype of each row, and
        # return the images plus formatted labels for optional plotting.
        data = ds.USPSDataset(DATA_DIR, usage, num_samples=3, shuffle=False)
        images, labels = [], []
        count = 0
        for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            image = item["image"]
            label = item["label"]
            images.append(image)
            labels.append("label {}".format(label))
            assert isinstance(image, np.ndarray)
            assert image.shape == (16, 16, 1)
            assert image.dtype == np.uint8
            assert label.dtype == np.uint32
            count += 1
        assert count == 3
        return images, labels

    for usage in ("train", "test"):
        image_list, label_list = collect_and_check(usage)
        if plot:
            visualize_dataset(image_list, label_list)
def test_usps_usage():
    """
    Validate USPSDataset usage flag handling.
    """
    logger.info("Test USPSDataset usage flag")

    def test_config(usage, path=None):
        # Returns the row count on success, or the error message on failure.
        if path is None:
            path = DATA_DIR
        try:
            data = ds.USPSDataset(path, usage=usage, shuffle=False)
            num_rows = sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    assert test_config("train") == 3
    assert test_config("test") == 3

    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

    # change this directory to the folder that contains all USPS files
    all_files_path = None
    # the following tests on the entire datasets
    if all_files_path is not None:
        assert test_config("train", all_files_path) == 3
        assert test_config("test", all_files_path) == 3
        assert ds.USPSDataset(all_files_path, usage="train").get_dataset_size() == 3
        assert ds.USPSDataset(all_files_path, usage="test").get_dataset_size() == 3
if __name__ == '__main__':
    # Run every USPS dataset test case directly; visualization is enabled
    # (plot=True) only for this interactive entry point, not under pytest.
    test_usps_content_check()
    test_usps_basic()
    test_usps_exception()
    test_usps_visualize(plot=True)
    test_usps_usage()