forked from mindspore-Ecosystem/mindspore

[feat][assistant][I3J6VL] add new data operator USPS

parent 72fb9207f8
commit d7c702e9d4

@@ -102,6 +102,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif

@@ -1198,6 +1199,14 @@ TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_f
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

USPSDataset::USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                         ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                         const std::shared_ptr<DatasetCache> &cache) {
  auto ds = std::make_shared<USPSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
                                       shard_id, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
                       const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
                       bool decode, const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,

@@ -44,6 +44,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif

@@ -286,6 +287,18 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
                    }));
                }));

PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
                  (void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
                                                                                     "to create a USPSNode")
                    .def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
                                     int32_t num_shards, int32_t shard_id) {
                      auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
                                                             num_shards, shard_id, nullptr);
                      THROW_IF_ERROR(usps->ValidateParams());
                      return usps;
                    }));
                }));
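
// Illustrative note (not part of the commit): the binding registered above is
// what the Python layer invokes from USPSDataset.parse() later in this change,
// e.g.
//   cde.USPSNode(dataset_dir, usage, num_samples, shuffle_flag, num_shards, shard_id)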

PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
                  (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
                    .def(py::init([](std::string dataset_dir, std::string task, std::string usage,

@@ -14,6 +14,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
    clue_op.cc
    csv_op.cc
    album_op.cc
    usps_op.cc
    mappable_leaf_op.cc
    nonmappable_leaf_op.cc
    cityscapes_op.cc

@@ -0,0 +1,350 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/usps_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>

#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace dataset {
constexpr int64_t kUSPSImageHeight = 16;
constexpr int64_t kUSPSImageWidth = 16;
constexpr int64_t kUSPSImageChannel = 1;
constexpr int64_t kUSPSImageSize = kUSPSImageHeight * kUSPSImageWidth * kUSPSImageChannel;

USPSOp::USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
               int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
               bool shuffle_files, int32_t num_devices, int32_t device_id)
    : NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
                        device_id),
      usage_(usage),
      dataset_dir_(dataset_dir),
      data_schema_(std::move(data_schema)) {}

void USPSOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nUSPS directory: " << dataset_dir_
        << "\nUSPS usage: " << usage_ << "\n\n";
    out << "\nData schema:\n";
    out << *data_schema_ << "\n\n";
  }
}

Status USPSOp::Init() {
  RETURN_IF_NOT_OK(this->GetFiles());
  RETURN_IF_NOT_OK(filename_index_->insert(data_files_list_));

  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(data_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
  return Status::OK();
}

Status USPSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();
  int32_t worker_connector_size = cfg->worker_connector_size();

  const int64_t num_samples = 0;
  const int32_t num_devices = 1;
  const int32_t device_id = 0;
  bool shuffle = false;

  auto op = std::make_shared<USPSOp>(dir, usage, std::move(schema), num_workers, worker_connector_size, num_samples,
                                     op_connector_size, shuffle, num_devices, device_id);
  RETURN_IF_NOT_OK(op->Init());
  // the logic of counting the number of samples
  for (auto data_file : op->FileNames()) {
    *count += op->CountRows(data_file);
  }
  return Status::OK();
}

int64_t USPSOp::CountRows(const std::string &data_file) {
  std::ifstream data_file_reader;
  data_file_reader.open(data_file, std::ios::in);
  if (!data_file_reader.is_open()) {
    MS_LOG(ERROR) << "Invalid file, failed to open file: " << data_file;
    return 0;
  }

  std::string line;
  int64_t count = 0;
  while (std::getline(data_file_reader, line)) {
    if (!line.empty()) {
      count++;
    }
  }
  data_file_reader.close();
  return count;
}

Status USPSOp::GetFiles() {
  auto real_dataset_dir = Common::GetRealPath(dataset_dir_);
  CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(), "Get real path failed: " + dataset_dir_);
  Path root_dir(real_dataset_dir.value());

  const Path train_file_name("usps");
  const Path test_file_name("usps.t");

  bool use_train = false;
  bool use_test = false;

  if (usage_ == "train") {
    use_train = true;
  } else if (usage_ == "test") {
    use_test = true;
  } else if (usage_ == "all") {
    use_train = true;
    use_test = true;
  }

  if (use_train) {
    Path train_path = root_dir / train_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(train_path.Exists() && !train_path.IsDirectory(),
                                 "Invalid file, failed to find USPS train data file: " + train_path.ToString());
    data_files_list_.emplace_back(train_path.ToString());
    MS_LOG(INFO) << "USPS operator found train data file " << train_path.ToString() << ".";
  }

  if (use_test) {
    Path test_path = root_dir / test_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(test_path.Exists() && !test_path.IsDirectory(),
                                 "Invalid file, failed to find USPS test data file: " + test_path.ToString());
    data_files_list_.emplace_back(test_path.ToString());
    MS_LOG(INFO) << "USPS operator found test data file " << test_path.ToString() << ".";
  }
  return Status::OK();
}

Status USPSOp::LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  std::ifstream data_file_reader(data_file);
  if (!data_file_reader.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + data_file);
  }

  int64_t rows_total = 0;
  std::string line;

  while (getline(data_file_reader, line)) {
    if (line.empty()) {
      continue;
    }
    // If read to the end offset of this file, break.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip line before start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    TensorRow tRow(1, nullptr);
    tRow.setPath({data_file});
    Status rc = LoadTensor(&line, &tRow);
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }
    rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }

    rows_total++;
  }

  data_file_reader.close();
  return Status::OK();
}

Status USPSOp::LoadTensor(std::string *line, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(line);
  RETURN_UNEXPECTED_IF_NULL(trow);

  auto images_buffer = std::make_unique<unsigned char[]>(kUSPSImageSize);
  auto labels_buffer = std::make_unique<uint32_t[]>(1);
  if (images_buffer == nullptr || labels_buffer == nullptr) {
    MS_LOG(ERROR) << "Failed to allocate memory for USPS buffer.";
    RETURN_STATUS_UNEXPECTED("Failed to allocate memory for USPS buffer.");
  }

  RETURN_IF_NOT_OK(this->ParseLine(line, images_buffer, labels_buffer));

  // create tensor
  std::shared_ptr<Tensor> image, label;
  TensorShape image_tensor_shape = TensorShape({kUSPSImageHeight, kUSPSImageWidth, kUSPSImageChannel});
  auto pixels = &images_buffer[0];
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(),
                                            reinterpret_cast<unsigned char *>(pixels), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(labels_buffer[0], &label));

  (*trow) = {std::move(image), std::move(label)};
  return Status::OK();
}

Status USPSOp::ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                         const std::unique_ptr<uint32_t[]> &labels_buffer) {
  auto label = &labels_buffer[0];
  auto pixels = &images_buffer[0];

  size_t pos = 0;
  int32_t split_num = 0;
  while ((pos = line->find(" ")) != std::string::npos) {
    split_num += 1;
    std::string item = line->substr(0, pos);

    if (split_num == 1) {
      // the class label is 1~10 but we need 0~9
      *label = static_cast<uint32_t>(std::stoi(item)) - 1;
    } else {
      size_t split_pos = item.find(":");

      CHECK_FAIL_RETURN_UNEXPECTED(split_pos != std::string::npos, "Invalid data, USPS data file is corrupted.");
      // check pixel index
      CHECK_FAIL_RETURN_UNEXPECTED(std::stoi(item.substr(0, split_pos)) == (split_num - 1),
                                   "Invalid data, USPS data file is corrupted.");

      std::string pixel_str = item.substr(split_pos + 1, item.length() - split_pos);
      // transform the real pixel value from [-1, 1] to the integers within [0, 255]
      pixels[split_num - 2] = static_cast<uint8_t>((std::stof(pixel_str) + 1.0) / 2.0 * 255.0);
    }
    line->erase(0, pos + 1);
  }

  CHECK_FAIL_RETURN_UNEXPECTED(split_num == (kUSPSImageSize + 1), "Invalid data, USPS data file is corrupted.");
  return Status::OK();
}

Status USPSOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountRows(it.value());
    filename_numrows_[it.value()] = count;
    num_rows_ += count;
  }
  if (num_rows_ == 0) {
    std::stringstream ss;
    for (int i = 0; i < data_files_list_.size(); ++i) {
      ss << " " << data_files_list_[i];
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED("Invalid data, data file may not be suitable to read with USPSDataset API. Check file: " +
                             file_list);
  }

  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}

Status USPSOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto ioBlock =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
      finish = false;
    } else {
      finish = true;
    }
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

Status USPSOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore

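// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the commit): how one line of a USPS data
// file is laid out and decoded by USPSOp::ParseLine above. The files use the
// LIBSVM text format
//   "<label> 1:<v1> 2:<v2> ... 256:<v256>"
// where the stored label runs 1..10 and every value is a pixel in [-1, 1].
// The shortened 3-pixel line below is made up for the demonstration.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::string line = "7 1:-1.000000 2:0.000000 3:1.000000";  // shortened sample
  std::istringstream iss(line);
  std::string token;
  iss >> token;
  // The stored class label is 1..10; the op shifts it to 0..9.
  uint32_t label = static_cast<uint32_t>(std::stoi(token)) - 1;

  std::vector<uint8_t> pixels;
  while (iss >> token) {
    std::string value = token.substr(token.find(':') + 1);
    // Same transform as USPSOp::ParseLine: [-1, 1] -> [0, 255].
    pixels.push_back(static_cast<uint8_t>((std::stof(value) + 1.0) / 2.0 * 255.0));
  }

  // Prints "label=6 pixels: 0 127 255".
  std::cout << "label=" << label << " pixels:";
  for (auto p : pixels) std::cout << " " << static_cast<int>(p);
  std::cout << std::endl;
  return 0;
}
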
@@ -0,0 +1,137 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/jagged_connector.h"

namespace mindspore {
namespace dataset {
class USPSOp : public NonMappableLeafOp {
 public:
  // Constructor.
  // @param const std::string &dataset_dir - directory of the USPS data files.
  // @param const std::string &usage - usage of this dataset, can be 'train', 'test' or 'all'.
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the USPS dataset.
  // @param num_workers - number of worker threads reading data from the USPS data files.
  // @param worker_connector_size - size of each internal queue.
  // @param num_samples - number of samples to read.
  // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
  // @param shuffle_files - whether to shuffle the files before reading data.
  // @param num_devices - number of devices.
  // @param device_id - device id.
  USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
         int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Destructor.
  ~USPSOp() = default;

  // Op name getter.
  // @return std::string - Name of the current Op.
  std::string Name() const override { return "USPSOp"; }

  // A print method typically used for debugging.
  // @param std::ostream &out - out stream.
  // @param bool show_all - whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  // Instantiates the internal queues and connectors.
  // @return Status - the error code returned.
  Status Init() override;

  // Function to count the number of samples in the USPS dataset.
  // @param const std::string &dir - path to the USPS directory.
  // @param const std::string &usage - usage of this dataset, can be 'train', 'test' or 'all'.
  // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples.
  // @return Status - the error code returned.
  static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);

  // File names getter.
  // @return Vector of the input file names.
  std::vector<std::string> FileNames() { return data_files_list_; }

 private:
  // Function to count the number of samples in one data file.
  // @param const std::string &data_file - path to the data file.
  // @return int64_t - the count result.
  int64_t CountRows(const std::string &data_file);

  // Reads a data file and loads the data into multiple TensorRows.
  // @param data_file - the data file to read.
  // @param start_offset - the start offset of file.
  // @param end_offset - the end offset of file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;

  // Parses a single row and puts the data into a tensor table.
  // @param line - the content of the row.
  // @param trow - image & label read into this tensor row.
  // @return Status - the error code returned.
  Status LoadTensor(std::string *line, TensorRow *trow);

  // Calculate number of rows in each shard.
  // @return Status - the error code returned.
  Status CalculateNumRowsPerShard() override;

  // Fill the IOBlockQueue.
  // @param i_keys - keys of files to fill to the IOBlockQueue.
  // @return Status - the error code returned.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;

  // Get all files in the dataset_dir_.
  // @return Status - the status code returned.
  Status GetFiles();

  // Parse a line to image and label.
  // @param line - the content of the row.
  // @param images_buffer - image destination.
  // @param labels_buffer - label destination.
  // @return Status - the status code returned.
  Status ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                   const std::unique_ptr<uint32_t[]> &labels_buffer);

  // Private function for computing the assignment of the column name map.
  // @return Status - the error code returned.
  Status ComputeColMap() override;

  const std::string usage_;  // can be "all", "train" or "test".
  std::string dataset_dir_;  // directory of data files.
  std::unique_ptr<DataSchema> data_schema_;

  std::vector<std::string> data_files_list_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_

@@ -91,6 +91,7 @@ constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUSPSNode[] = "USPSDataset";
constexpr char kVOCNode[] = "VOCDataset";

Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,

@@ -19,6 +19,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
    random_node.cc
    text_file_node.cc
    tf_record_node.cc
    usps_node.cc
    voc_node.cc
)

@@ -0,0 +1,173 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/usps_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

USPSNode::USPSNode(std::string dataset_dir, std::string usage, int32_t num_samples, ShuffleMode shuffle,
                   int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
    : NonMappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {
  // Update the num_shards_ in the global context. This number is only used for now by auto_num_worker_pass. User
  // discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the
  // num_shards_ isn't 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return
  // num_shards. Once PreBuildSampler is phased out, this can be cleaned up.
  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

std::shared_ptr<DatasetNode> USPSNode::Copy() {
  auto node = std::make_shared<USPSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
  return node;
}

void USPSNode::Print(std::ostream &out) const {
  out << (Name() + "(dataset dir:" + dataset_dir_ + ", usage:" + usage_ +
          ", num_shards:" + std::to_string(num_shards_) + ", shard_id:" + std::to_string(shard_id_) +
          ", num_samples:" + std::to_string(num_samples_) + ")");
}

Status USPSNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSNode", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateStringValue("USPSNode", usage_, {"train", "test", "all"}));

  if (num_samples_ < 0) {
    std::string err_msg = "USPSNode: Invalid number of samples: " + std::to_string(num_samples_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSNode", num_shards_, shard_id_));
  return Status::OK();
}

Status USPSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  auto op = std::make_shared<USPSOp>(dataset_dir_, usage_, std::move(schema), num_workers_, worker_connector_size_,
                                     num_samples_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
  RETURN_IF_NOT_OK(op->Init());

  // If a global shuffle is used for USPS, it will inject a shuffle op over the USPS.
  // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
  // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset USPS's shuffle
  // option to false.
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows));

    // Add the shuffle op after this op
    RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
    shuffle_op->set_total_repeats(GetTotalRepeats());
    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
    node_ops->push_back(shuffle_op);
  }
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}

// Get the shard id of node
Status USPSNode::GetShardId(int32_t *shard_id) {
  *shard_id = shard_id_;
  return Status::OK();
}

// Get Dataset size
Status USPSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size = num_samples_;
  RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
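
// Illustrative note (not part of the commit): a worked example of the sharding
// math above. With the full USPS set of 9,298 rows and num_shards_ = 4, each
// shard reports ceil(9298 / 4.0) = 2325 rows; a positive num_samples_ then
// caps that per-shard figure if it is smaller.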

Status USPSNode::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["num_parallel_workers"] = num_workers_;
  args["dataset_dir"] = dataset_dir_;
  args["usage"] = usage_;
  args["num_samples"] = num_samples_;
  args["shuffle"] = shuffle_;
  args["num_shards"] = num_shards_;
  args["shard_id"] = shard_id_;
  if (cache_ != nullptr) {
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}

// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// USPS by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we set up the sampler for a leaf node that does not use sampling.
Status USPSNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
  return Status::OK();
}

// If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the USPS node need to be reset to their defaults so
// that this USPS node will produce the full set of data into the cache.
Status USPSNode::MakeSimpleProducer() {
  shard_id_ = 0;
  num_shards_ = 1;
  shuffle_ = ShuffleMode::kFalse;
  num_samples_ = 0;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore

@@ -0,0 +1,122 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

class USPSNode : public NonMappableSourceNode {
 public:
  /// \brief Constructor.
  USPSNode(std::string dataset_dir, std::string usage, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
           int32_t shard_id, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor.
  ~USPSNode() = default;

  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kUSPSNode; }

  /// \brief Print the description.
  /// \param out - The output stream to write output to.
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief A base class override function to create the required runtime dataset op objects for this class.
  /// \param node_ops - A vector containing shared pointers to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build successfully.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;

  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;

  /// \brief Get the shard id of node.
  /// \return Status Status::OK() if get shard id successfully.
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }

  /// \brief Getter functions.
  const std::string &Usage() const { return usage_; }

  /// \brief Getter functions.
  int32_t NumSamples() const { return num_samples_; }

  /// \brief Getter functions.
  int32_t NumShards() const { return num_shards_; }

  /// \brief Getter functions.
  int32_t ShardId() const { return shard_id_; }

  /// \brief Getter functions.
  ShuffleMode Shuffle() const { return shuffle_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

  /// \brief USPS by itself is a non-mappable dataset that does not support sampling.
  ///     However, if a cache operator is injected at some other place higher in the tree, that cache can
  ///     inherit this sampler from the leaf, providing sampling support from the caching layer.
  ///     That is why we set up the sampler for a leaf node that does not use sampling.
  ///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
  /// \param[in] sampler The sampler to set up.
  /// \return Status of the function.
  Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;

  /// \brief If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing
  ///     a sampler for fetching the data. As such, any options in the USPS node need to be reset to their defaults
  ///     so that this USPS node will produce the full set of data into the cache.
  ///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
  /// \return Status of the function.
  Status MakeSimpleProducer() override;

 private:
  std::string dataset_dir_;
  std::string usage_;
  int32_t num_samples_;
  ShuffleMode shuffle_;
  int32_t num_shards_;
  int32_t shard_id_;
};

}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_

@@ -2386,6 +2386,56 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase
  return ds;
}

/// \class USPSDataset
/// \brief A source dataset that reads and parses USPS datasets.
class USPSDataset : public Dataset {
 public:
  /// \brief Constructor of USPSDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
  /// \param[in] num_samples The number of samples to be included in the dataset
  ///     (Default = 0, which means all samples).
  /// \param[in] shuffle The mode for shuffling data every epoch (Default = ShuffleMode::kGlobal).
  ///     Can be any of:
  ///     ShuffleMode::kFalse - No shuffling is performed.
  ///     ShuffleMode::kFiles - Shuffle files only.
  ///     ShuffleMode::kGlobal - Shuffle both the files and samples.
  /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
  /// \param[in] shard_id The shard ID within num_shards. This argument should be
  ///     specified only when num_shards is also specified (Default = 0).
  /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
  explicit USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                       ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                       const std::shared_ptr<DatasetCache> &cache);

  /// Destructor of USPSDataset.
  ~USPSDataset() = default;
};

/// \brief Function to create a USPSDataset.
/// \notes The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Usage of USPS, can be "train", "test" or "all" (Default = "all").
/// \param[in] num_samples The number of samples to be included in the dataset
///     (Default = 0, which means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default = ShuffleMode::kGlobal).
///     Can be any of:
///     ShuffleMode::kFalse - No shuffling is performed.
///     ShuffleMode::kFiles - Shuffle files only.
///     ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
///     specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the current USPSDataset.
inline std::shared_ptr<USPSDataset> USPS(const std::string &dataset_dir, const std::string &usage = "all",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<USPSDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle, num_shards,
                                       shard_id, cache);
}
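
// Illustrative usage sketch of the API above (not part of the commit; the
// directory path is a placeholder). It mirrors the pipeline exercised by the
// unit tests later in this change:
//
//   std::shared_ptr<Dataset> ds = USPS("/path/to/usps_dataset_dir", "train");
//   std::shared_ptr<Iterator> iter = ds->CreateIterator();
//   std::unordered_map<std::string, mindspore::MSTensor> row;
//   iter->GetNextRow(&row);  // row["image"]: uint8, shape <16,16,1>; row["label"]: uint32 scalar
//   iter->Stop();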

/// \class VOCDataset
/// \brief A source dataset for reading and parsing VOC dataset.
class VOCDataset : public Dataset {

@@ -46,6 +46,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
  friend class RandomDataDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

@@ -64,7 +64,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
     check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
     check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
-    check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset
+    check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset
 from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
     get_prefetch_size
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

@@ -4712,6 +4712,103 @@ class Schema:
        return schema_obj.cpp_schema.get_num_rows()


class USPSDataset(SourceDataset):
    """
    A source dataset for reading and parsing the USPS dataset.

    The generated dataset has two columns: :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar of the uint32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 7,291
            train samples, "test" will read from 2,007 test samples, "all" will read from all 9,298 samples
            (default=None, will read all samples).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, will read all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, will use the value set in the config).
        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
            (default=Shuffle.GLOBAL).
            If shuffle is False, no shuffling will be performed.
            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
            Otherwise, there are two levels of shuffling:

            - Shuffle.GLOBAL: Shuffle both the files and samples.

            - Shuffle.FILES: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the max sample number of each shard.
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If dataset_dir is not valid or does not exist or does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If usage is invalid.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Examples:
        >>> usps_dataset_dir = "/path/to/usps_dataset_directory"
        >>>
        >>> # Read 3 samples from USPS dataset
        >>> dataset = ds.USPSDataset(dataset_dir=usps_dataset_dir, num_samples=3)
        >>>
        >>> # Note: In USPS dataset, each dictionary has keys "image" and "label"

    About USPS dataset:

    USPS is a digit dataset automatically scanned from envelopes by the U.S. Postal Service,
    containing a total of 9,298 16×16 pixel grayscale samples.
    The images are centered, normalized and show a broad range of font styles.

    Here is the original USPS dataset structure.
    You can download and unzip the dataset files into this directory structure and read them with MindSpore's API.

    .. code-block::

        .
        └── usps_dataset_dir
            ├── usps
            ├── usps.t

    Citation:

    .. code-block::

        @article{hull1994database,
          title={A database for handwritten text recognition research},
          author={Hull, Jonathan J.},
          journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
          volume={16},
          number={5},
          pages={550--554},
          year={1994},
          publisher={IEEE}
        }
    """

    @check_usps_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
                 num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir
        self.usage = replace_none(usage, "all")

    def parse(self, children=None):
        return cde.USPSNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
                            self.shard_id)


class VOCDataset(MappableDataset):
    """
    A source dataset for reading and parsing VOC dataset.

@@ -151,6 +151,33 @@ def check_tfrecorddataset(method):
    return new_method


def check_usps_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
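
# Illustrative note (not part of the commit): given the checks above, an
# invalid split such as ds.USPSDataset(usps_dir, usage="validation") raises a
# ValueError from check_valid_str, and passing num_shards without shard_id is
# rejected by check_sampler_shuffle_shard_options, matching the Raises section
# of the USPSDataset docstring.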


def check_vocdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

@@ -31,6 +31,7 @@ SET(DE_UT_SRCS
    c_api_dataset_save.cc
    c_api_dataset_textfile_test.cc
    c_api_dataset_tfrecord_test.cc
    c_api_dataset_usps_test.cc
    c_api_dataset_voc_test.cc
    c_api_datasets_test.cc
    c_api_epoch_ctrl_test.cc

@ -0,0 +1,255 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestUSPSTrainDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDataset.";
|
||||
|
||||
// Create a USPS Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
|
||||
std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestUSPSTestDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTestDataset.";
|
||||
|
||||
// Create a USPS Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
|
||||
std::shared_ptr<Dataset> ds = USPS(folder_path, "test");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestUSPSAllDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSAllDataset.";
|
||||
|
||||
// Create a USPS Test Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
|
||||
std::shared_ptr<Dataset> ds = USPS(folder_path, "all");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 6);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestUSPSDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDatasetWithPipeline.";
|
||||
|
||||
// Create two USPS Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
|
||||
std::shared_ptr<Dataset> ds1 = USPS(folder_path, "train");
|
||||
std::shared_ptr<Dataset> ds2 = USPS(folder_path, "train");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 1;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 1;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 6);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestGetUSPSDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetUSPSTrainDatasetSize.";
|
||||
|
||||
// Create a USPS Train Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
|
||||
std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
}

TEST_F(MindDataTestPipeline, TestUSPSDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetGetters.";

  // Create a USPS Train Dataset
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "train");
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 3);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  EXPECT_EQ(types[1].ToString(), "uint32");
  EXPECT_EQ(shapes.size(), 2);
  EXPECT_EQ(shapes[0].ToString(), "<16,16,1>");
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // Call the getters again to make sure the cached values stay consistent
  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetDatasetSize(), 3);
}

TEST_F(MindDataTestPipeline, TestUSPSDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetFail.";

  // Create a USPS Dataset with an empty dataset directory
  std::shared_ptr<Dataset> ds = USPS("", "train");
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid USPS input
  EXPECT_EQ(iter, nullptr);
}

TEST_F(MindDataTestPipeline, TestUSPSDatasetWithInvalidUsageFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetWithInvalidUsageFail.";

  // Create a USPS Dataset with an invalid usage value
  std::string folder_path = datasets_root_path_ + "/testUSPSDataset/";
  std::shared_ptr<Dataset> ds = USPS(folder_path, "validation");
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: "validation" is not a valid usage
  EXPECT_EQ(iter, nullptr);
}

@@ -0,0 +1,3 @@
1 1:0.600000 2:-0.81176 3:0.937254 4:0.160784 5:-0.25490 6:0.615686 7:-0.02745 8:0.866666 9:-0.03529 10:0.137254 11:-0.52941 12:-0.20784 13:0.411764 14:-0.71764 15:-0.59215 16:-0.08235 17:-0.6 18:0.505882 19:-0.09803 20:-0.22352 21:-0.30980 22:0.725490 23:0.294117 24:-0.70196 25:0.929411 26:-0.30196 27:0.247058 28:0.035294 29:-0.46666 30:-0.86666 31:-0.35686 32:-0.54509 33:0.992156 34:-0.04313 35:-0.84313 36:-0.39607 37:-0.16078 38:-0.72549 39:-1.0 40:-0.10588 41:0.090196 42:-0.29411 43:-0.42745 44:-0.63137 45:-0.03529 46:-0.44313 47:-0.16078 48:-0.62352 49:0.254901 50:0.168627 51:0.545098 52:-0.37254 53:-0.89019 54:-0.63137 55:-0.18431 56:0.631372 57:0.749019 58:-0.78039 59:0.058823 60:-0.45882 61:-0.47450 62:-0.21568 63:0.890196 64:-0.69411 65:0.976470 66:-0.81176 67:0.419607 68:-0.95294 69:-0.56078 70:-0.11372 71:0.145098 72:0.105882 73:0.537254 74:-0.73333 75:-0.30196 76:-0.56862 77:0.811764 78:0.552941 79:0.152941 80:-0.88235 81:-0.69411 82:0.647058 83:-0.11372 84:-0.99215 85:-0.85098 86:-0.43529 87:0.278431 88:-0.83529 89:0.427450 90:-0.30196 91:0.725490 92:0.929411 93:0.239215 94:0.176470 95:0.709803 96:-0.00392 97:0.317647 98:-0.17647 99:0.098039 100:-0.78039 101:-0.21568 102:0.976470 103:0.670588 104:0.349019 105:0.623529 106:0.631372 107:0.529411 108:-0.53725 109:0.694117 110:0.505882 111:-0.70196 112:0.592156 113:0.458823 114:-0.67058 115:-0.37254 116:0.450980 117:-0.67058 118:0.482352 119:-0.62352 120:0.741176 121:0.552941 122:0.105882 123:-0.12941 124:0.278431 125:-0.36470 126:0.827450 127:0.537254 128:0.827450 129:-0.49803 130:0.341176 131:-0.74901 132:-0.85098 133:0.247058 134:-0.75686 135:0.678431 136:0.670588 137:-1.0 138:-0.23137 139:-0.30196 140:0.223529 141:0.325490 142:-0.63137 143:-0.99215 144:-0.45882 145:0.474509 146:0.341176 147:0.552941 148:-0.11372 149:-0.34117 150:-0.86666 151:-0.56862 152:0.168627 153:-0.11372 154:0.098039 155:-0.94509 156:-0.01960 157:-0.45098 158:0.168627 159:-0.94509 160:-0.84313 161:-0.44313 162:-0.05882 163:-0.05882 164:-0.49803 165:0.105882 166:-0.96862 167:0.725490 168:-0.30196 169:-0.21568 170:0.803921 171:0.647058 172:-0.46666 173:-0.52941 174:0.796078 175:0.654901 176:-0.81176 177:-0.30196 178:0.811764 179:-0.63137 180:-0.41176 181:-0.12156 182:-1.0 183:-0.59215 184:0.631372 185:-0.41960 186:0.937254 187:0.396078 188:0.137254 189:-0.50588 190:0.443137 191:0.474509 192:0.121568 193:0.074509 194:-0.06666 195:-0.63921 196:0.286274 197:0.019607 198:-0.14509 199:-0.56862 200:0.537254 201:0.333333 202:0.027450 203:-0.95294 204:0.678431 205:0.474509 206:0.027450 207:0.349019 208:-0.91372 209:0.772549 210:0.223529 211:-0.96078 212:-0.38823 213:-0.50588 214:0.764705 215:0.490196 216:0.458823 217:0.945098 218:-0.17647 219:0.560784 220:-0.01960 221:0.474509 222:-0.84313 223:0.043137 224:0.615686 225:0.388235 226:0.082352 227:-0.45098 228:-0.97647 229:-0.58431 230:0.905882 231:0.372549 232:-0.23137 233:-0.07450 234:0.262745 235:-0.70980 236:0.733333 237:-0.51372 238:-0.63921 239:0.443137 240:0.678431 241:-0.64705 242:0.631372 243:-0.98431 244:0.411764 245:0.019607 246:0.239215 247:0.380392 248:-0.41960 249:0.678431 250:-0.81960 251:0.364705 252:-0.93725 253:-0.05882 254:0.623529 255:0.160784 256:-0.19999
7 1:0.294117 2:0.576470 3:-0.98431 4:-0.19999 5:-0.90588 6:0.074509 7:-0.54509 8:-0.98431 9:-0.67843 10:-0.04313 11:-0.56862 12:0.733333 13:-0.85098 14:0.082352 15:-0.50588 16:-0.37254 17:0.160784 18:0.772549 19:-0.69411 20:-0.07450 21:-0.90588 22:0.443137 23:-0.25490 24:0.678431 25:-0.38039 26:-0.87450 27:0.286274 28:-0.67058 29:0.623529 30:0.639215 31:-0.86666 32:0.545098 33:-0.22352 34:-0.05098 35:-0.81176 36:-0.15294 37:0.380392 38:0.058823 39:-0.26274 40:-0.85882 41:0.945098 42:0.945098 43:0.231372 44:0.356862 45:0.874509 46:0.654901 47:0.647058 48:-0.80392 49:0.396078 50:0.129411 51:-0.99215 52:0.945098 53:-0.01176 54:0.199999 55:-0.34901 56:-0.58431 57:0.301960 58:-0.68627 59:0.639215 60:0.176470 61:0.011764 62:0.796078 63:0.513725 64:-0.30980 65:0.615686 66:0.356862 67:0.545098 68:-0.85882 69:0.199999 70:-0.65490 71:-0.15294 72:-0.09803 73:0.090196 74:-0.11372 75:-0.55294 76:0.294117 77:-0.38823 78:0.184313 79:0.458823 80:0.247058 81:0.945098 82:0.976470 83:-0.16078 84:0.380392 85:-0.51372 86:0.984313 87:0.945098 88:0.215686 89:0.207843 90:0.301960 91:-0.67843 92:-0.34117 93:0.741176 94:0.372549 95:-0.70196 96:0.066666 97:0.890196 98:0.600000 99:-0.15294 100:-0.57647 101:0.678431 102:-0.25490 103:0.231372 104:-0.10588 105:0.788235 106:0.796078 107:0.301960 108:0.121568 109:0.356862 110:-0.12941 111:0.231372 112:0.066666 113:0.709803 114:0.137254 115:0.905882 116:-0.82745 117:0.443137 118:0.929411 119:-0.05098 120:0.537254 121:-0.53725 122:-0.34901 123:0.850980 124:0.576470 125:0.670588 126:0.380392 127:-0.09803 128:0.952941 129:0.372549 130:-0.96862 131:0.882352 132:0.231372 133:-0.59215 134:0.341176 135:0.694117 136:0.192156 137:-0.03529 138:-0.60784 139:0.450980 140:0.372549 141:-0.10588 142:-0.44313 143:0.741176 144:-0.35686 145:-0.41176 146:0.411764 147:-0.05098 148:-0.57647 149:-0.37254 150:-0.36470 151:0.082352 152:-0.05882 153:-0.29411 154:0.788235 155:0.576470 156:-0.79607 157:0.254901 158:-0.09019 159:0.427450 160:0.301960 161:-0.10588 162:0.952941 163:0.466666 164:-0.59215 165:-0.95294 166:0.435294 167:-0.12156 168:-0.70980 169:0.262745 170:0.960784 171:-0.78039 172:-0.63921 173:-1.0 174:-0.53725 175:0.678431 176:-0.41176 177:-0.94509 178:0.129411 179:0.372549 180:0.521568 181:-0.75686 182:0.615686 183:-0.04313 184:0.262745 185:-0.37254 186:0.145098 187:0.537254 188:-0.12156 189:-0.01176 190:0.560784 191:0.011764 192:0.160784 193:0.647058 194:-0.57647 195:0.835294 196:0.905882 197:0.278431 198:0.811764 199:-0.86666 200:-0.78823 201:-0.44313 202:0.600000 203:0.505882 204:0.254901 205:0.215686 206:-0.33333 207:-0.65490 208:-0.33333 209:0.317647 210:-0.00392 211:-0.60784 212:-0.63921 213:0.890196 214:-0.64705 215:-0.56078 216:-0.59215 217:0.654901 218:0.168627 219:0.537254 220:-0.76470 221:0.301960 222:-0.33333 223:-0.74117 224:0.098039 225:0.152941 226:0.694117 227:-0.18431 228:0.843137 229:0.592156 230:-0.38039 231:-0.30980 232:0.160784 233:0.262745 234:-0.96862 235:-0.41960 236:0.380392 237:0.623529 238:0.858823 239:-0.24705 240:-0.38823 241:0.356862 242:0.576470 243:-0.38823 244:-0.60784 245:-0.24705 246:-0.54509 247:0.317647 248:0.341176 249:0.913725 250:0.003921 251:-0.34117 252:0.513725 253:-0.96078 254:-0.84313 255:-0.93725 256:-0.92941
1 1:0.411764 2:0.929411 3:-0.03529 4:0.231372 5:-0.30980 6:0.286274 7:-0.58431 8:0.937254 9:0.388235 10:-0.89803 11:-0.16862 12:0.756862 13:-0.67843 14:-0.89019 15:0.749019 16:0.074509 17:-0.25490 18:0.741176 19:-0.45098 20:-0.22352 21:0.788235 22:0.576470 23:-0.59215 24:-0.51372 25:0.560784 26:-0.97647 27:0.670588 28:-0.27843 29:-0.64705 30:0.984313 31:-0.22352 32:0.027450 33:-0.96862 34:-0.78039 35:0.694117 36:0.349019 37:-0.53725 38:-0.37254 39:-0.49019 40:-0.68627 41:0.325490 42:0.992156 43:0.098039 44:0.270588 45:0.882352 46:0.043137 47:-0.39607 48:-0.43529 49:-0.75686 50:-0.21568 51:-0.86666 52:-0.67058 53:-0.62352 54:0.607843 55:0.552941 56:0.380392 57:0.568627 58:-0.31764 59:0.639215 60:0.607843 61:-0.37254 62:0.176470 63:0.098039 64:-0.53725 65:-0.71764 66:-0.30196 67:-0.44313 68:0.568627 69:-0.85098 70:0.505882 71:0.537254 72:-0.41960 73:0.474509 74:0.254901 75:-0.16862 76:-0.44313 77:-0.73333 78:-0.08235 79:0.537254 80:-0.58431 81:-0.96862 82:0.733333 83:-0.01176 84:-0.87450 85:0.701960 86:0.890196 87:-0.61568 88:0.694117 89:-0.83529 90:-0.34117 91:-0.59215 92:0.952941 93:0.239215 94:-0.87450 95:0.803921 96:0.905882 97:0.545098 98:-0.32549 99:-0.38823 100:1.0 101:-0.50588 102:0.835294 103:0.254901 104:-0.70196 105:-0.77254 106:-0.55294 107:0.474509 108:-0.78039 109:-0.88235 110:0.388235 111:-0.52941 112:0.443137 113:0.129411 114:-0.89019 115:0.145098 116:-0.16862 117:0.568627 118:-0.27843 119:-0.23137 120:-0.05882 121:0.003921 122:0.662745 123:-0.87450 124:0.474509 125:0.639215 126:-0.97647 127:0.701960 128:0.247058 129:0.545098 130:-0.70980 131:0.913725 132:-0.63921 133:-0.66274 134:-0.16862 135:-0.63921 136:-0.05098 137:-0.31764 138:-0.54509 139:-0.83529 140:-0.83529 141:0.215686 142:-0.28627 143:-0.39607 144:-0.06666 145:-0.81176 146:0.145098 147:-0.12941 148:-0.44313 149:-0.10588 150:0.199999 151:0.537254 152:-0.36470 153:-0.48235 154:-0.77254 155:-0.96862 156:-0.67843 157:0.850980 158:-0.45882 159:-0.99215 160:0.231372 161:-0.54509 162:-0.77254 163:0.937254 164:-0.12941 165:-0.32549 166:-0.01176 167:-0.40392 168:0.607843 169:0.835294 170:-0.83529 171:0.082352 172:0.450980 173:-0.85098 174:0.207843 175:0.239215 176:0.827450 177:0.490196 178:-0.68627 179:-0.20784 180:-0.46666 181:0.474509 182:-0.97647 183:-0.30980 184:0.145098 185:0.717647 186:0.411764 187:-0.67843 188:-0.40392 189:-0.90588 190:-0.87450 191:0.113725 192:0.333333 193:0.952941 194:-0.58431 195:-0.83529 196:-0.32549 197:0.882352 198:0.301960 199:-0.82745 200:-0.53725 201:0.952941 202:0.600000 203:0.921568 204:-0.40392 205:0.945098 206:-0.31764 207:0.121568 208:-0.36470 209:-0.19999 210:-0.01176 211:0.011764 212:-0.31764 213:-0.45098 214:-0.94509 215:-0.05098 216:-0.42745 217:0.749019 218:-0.64705 219:-0.81176 220:0.505882 221:-0.66274 222:0.623529 223:0.882352 224:-0.18431 225:0.835294 226:0.968627 227:0.050980 228:0.160784 229:-0.41960 230:0.662745 231:-0.26274 232:-0.64705 233:-0.63137 234:-0.89019 235:0.694117 236:-0.56862 237:0.819607 238:0.686274 239:-0.18431 240:-0.57647 241:-0.6 242:0.976470 243:0.160784 244:-0.68627 245:0.223529 246:0.615686 247:0.074509 248:-0.64705 249:0.717647 250:0.921568 251:0.058823 252:0.349019 253:-0.31764 254:-0.13725 255:-0.49019 256:-0.57647

@@ -0,0 +1,3 @@
8 1:-0.41176 2:-0.22352 3:0.129411 4:0.380392 5:0.545098 6:-0.44313 7:-0.78039 8:0.725490 9:-0.36470 10:0.529411 11:0.309803 12:-0.34117 13:-0.27058 14:0.521568 15:0.780392 16:0.890196 17:-0.76470 18:0.356862 19:-0.09019 20:-0.49019 21:0.694117 22:-0.09019 23:0.301960 24:-0.44313 25:-0.87450 26:0.709803 27:0.749019 28:-0.89019 29:0.223529 30:-0.89019 31:0.490196 32:-0.85098 33:0.694117 34:0.090196 35:0.035294 36:0.788235 37:0.333333 38:0.419607 39:0.678431 40:-0.74117 41:0.137254 42:0.756862 43:0.576470 44:-0.97647 45:-0.38823 46:-0.36470 47:-0.67843 48:0.678431 49:-0.58431 50:-0.58431 51:0.309803 52:-0.52941 53:-0.67843 54:0.403921 55:0.858823 56:-0.25490 57:0.717647 58:-0.47450 59:-0.92941 60:0.529411 61:-0.37254 62:-0.89019 63:-0.51372 64:-0.91372 65:0.811764 66:-0.78039 67:1.0 68:0.600000 69:-0.20784 70:0.239215 71:0.027450 72:0.388235 73:-0.38823 74:-0.81960 75:-0.24705 76:-0.25490 77:-0.09803 78:0.568627 79:-0.40392 80:0.882352 81:-0.44313 82:0.199999 83:0.568627 84:-0.88235 85:0.066666 86:0.035294 87:0.490196 88:0.701960 89:-0.42745 90:0.560784 91:-1.0 92:-0.96078 93:0.631372 94:0.552941 95:0.960784 96:-0.92156 97:0.913725 98:-0.04313 99:0.843137 100:0.741176 101:0.364705 102:0.654901 103:0.529411 104:0.709803 105:-0.69411 106:0.670588 107:0.976470 108:-0.28627 109:0.482352 110:0.529411 111:0.670588 112:-0.46666 113:0.443137 114:-0.52156 115:0.168627 116:-0.05882 117:0.145098 118:-0.05098 119:0.811764 120:-0.91372 121:0.113725 122:0.725490 123:0.427450 124:-0.00392 125:0.725490 126:0.505882 127:0.333333 128:-0.22352 129:0.490196 130:-0.78039 131:0.458823 132:-0.11372 133:0.137254 134:0.615686 135:-0.54509 136:-0.81176 137:-0.23921 138:-0.62352 139:-0.87450 140:0.560784 141:-0.92941 142:0.858823 143:0.364705 144:0.035294 145:0.176470 146:0.921568 147:0.552941 148:-0.60784 149:0.396078 150:0.168627 151:0.254901 152:-0.40392 153:-0.27058 154:0.396078 155:-0.29411 156:0.011764 157:0.003921 158:-0.75686 159:-0.10588 160:0.741176 161:-0.06666 162:-0.97647 163:-0.78039 164:0.356862 165:0.341176 166:-0.05882 167:0.576470 168:-0.6 169:0.647058 170:-0.80392 171:-0.12941 172:0.513725 173:0.701960 174:0.937254 175:-0.49019 176:-0.70980 177:0.960784 178:0.474509 179:-0.17647 180:-0.37254 181:-0.23137 182:-0.06666 183:-0.37254 184:-0.76470 185:0.474509 186:0.168627 187:-0.71764 188:-0.48235 189:0.505882 190:-0.13725 191:0.623529 192:0.333333 193:-0.43529 194:-0.34901 195:-0.19215 196:0.717647 197:0.513725 198:-0.25490 199:-0.30196 200:0.098039 201:0.937254 202:-0.63921 203:-0.37254 204:-0.99215 205:0.301960 206:0.482352 207:-0.21568 208:0.050980 209:0.419607 210:0.756862 211:0.701960 212:0.294117 213:-0.45882 214:0.341176 215:0.992156 216:0.003921 217:0.411764 218:-0.33333 219:0.427450 220:-0.97647 221:0.545098 222:0.341176 223:-0.90588 224:0.811764 225:0.176470 226:-0.78823 227:0.050980 228:0.733333 229:0.286274 230:-0.78039 231:-0.89803 232:0.490196 233:0.788235 234:-0.52156 235:-0.78039 236:0.035294 237:0.600000 238:-0.12156 239:-0.09803 240:-0.35686 241:-0.82745 242:-0.05098 243:-0.67058 244:-0.74901 245:-0.30980 246:-0.94509 247:0.419607 248:-0.97647 249:-0.52941 250:-0.56078 251:-0.96078 252:-0.83529 253:0.905882 254:-0.32549 255:-0.68627 256:0.215686
7 1:-0.33333 2:0.380392 3:0.866666 4:-0.58431 5:0.403921 6:0.145098 7:0.074509 8:0.027450 9:-0.31764 10:0.239215 11:0.333333 12:-0.38823 13:0.662745 14:-0.56862 15:-0.27058 16:0.505882 17:-0.79607 18:0.898039 19:-0.45882 20:-0.98431 21:0.152941 22:-0.03529 23:0.270588 24:0.866666 25:-0.95294 26:0.803921 27:0.694117 28:0.450980 29:0.733333 30:-0.15294 31:0.286274 32:-0.95294 33:-0.93725 34:0.654901 35:-0.25490 36:-0.81176 37:0.678431 38:0.545098 39:0.035294 40:-0.05098 41:-0.46666 42:-0.05882 43:0.764705 44:0.545098 45:0.866666 46:0.631372 47:0.356862 48:-1.0 49:-0.92941 50:-0.44313 51:-0.30196 52:-0.02745 53:0.356862 54:-0.75686 55:0.741176 56:-0.54509 57:-0.73333 58:-0.52941 59:-0.70980 60:-0.85882 61:0.929411 62:0.223529 63:0.654901 64:-0.45098 65:0.552941 66:-0.03529 67:0.498039 68:-0.04313 69:0.600000 70:0.490196 71:0.403921 72:1.0 73:0.741176 74:-0.86666 75:-0.52941 76:-0.17647 77:-0.05098 78:-0.96862 79:0.105882 80:-0.34117 81:0.403921 82:-0.56078 83:0.796078 84:-0.05882 85:0.521568 86:-0.41960 87:0.262745 88:-0.49803 89:0.035294 90:-0.09803 91:0.356862 92:-0.92156 93:-0.70196 94:-0.59215 95:-0.17647 96:-0.96862 97:0.349019 98:-0.23137 99:0.984313 100:0.207843 101:0.011764 102:-0.19999 103:-0.83529 104:0.121568 105:0.058823 106:-0.19215 107:0.670588 108:-0.34117 109:0.890196 110:0.921568 111:-0.65490 112:0.772549 113:0.741176 114:-0.29411 115:-0.64705 116:-0.49803 117:0.058823 118:-0.6 119:0.035294 120:-0.53725 121:-0.97647 122:0.811764 123:0.458823 124:0.945098 125:0.262745 126:0.152941 127:-0.06666 128:0.788235 129:-0.89803 130:-0.26274 131:0.466666 132:0.043137 133:0.607843 134:0.984313 135:-0.22352 136:0.874509 137:0.223529 138:0.835294 139:-0.75686 140:0.858823 141:-0.72549 142:0.254901 143:0.262745 144:0.105882 145:0.192156 146:0.772549 147:-0.19999 148:-0.28627 149:-0.40392 150:-0.96862 151:-0.47450 152:-0.34117 153:0.662745 154:0.615686 155:-0.46666 156:-0.26274 157:0.529411 158:0.741176 159:-0.87450 160:-0.41960 161:0.921568 162:-0.28627 163:0.247058 164:-0.12941 165:0.513725 166:-0.84313 167:-0.78039 168:0.301960 169:0.121568 170:0.286274 171:0.129411 172:0.623529 173:0.976470 174:-0.39607 175:0.513725 176:-0.63921 177:0.717647 178:0.388235 179:0.443137 180:0.184313 181:-0.49019 182:-0.95294 183:-0.53725 184:-0.29411 185:0.992156 186:0.003921 187:0.223529 188:-0.44313 189:0.443137 190:-0.69411 191:-0.93725 192:-0.54509 193:-0.96078 194:-0.52156 195:-0.06666 196:-0.80392 197:0.913725 198:0.380392 199:-0.35686 200:0.309803 201:0.145098 202:0.913725 203:0.741176 204:0.262745 205:0.952941 206:-0.66274 207:0.827450 208:-0.38823 209:0.105882 210:0.003921 211:0.325490 212:-0.41960 213:0.615686 214:0.521568 215:0.686274 216:-0.48235 217:-0.82745 218:-0.09803 219:-0.19215 220:0.639215 221:-0.09803 222:0.733333 223:-0.35686 224:0.286274 225:0.498039 226:0.152941 227:0.749019 228:0.294117 229:0.505882 230:-0.30196 231:-0.42745 232:0.639215 233:0.137254 234:0.223529 235:-0.6 236:-0.82745 237:0.349019 238:0.388235 239:0.631372 240:-0.41176 241:-0.70980 242:0.090196 243:-0.99215 244:0.270588 245:0.011764 246:0.968627 247:0.850980 248:-0.16862 249:0.686274 250:0.537254 251:-0.37254 252:0.647058 253:0.152941 254:0.717647 255:-0.45098 256:0.623529
6 1:0.443137 2:-0.43529 3:0.223529 4:-0.88235 5:0.694117 6:0.435294 7:-0.41960 8:-0.03529 9:0.552941 10:0.709803 11:-0.58431 12:0.701960 13:-0.23137 14:-0.03529 15:-0.55294 16:1.0 17:0.819607 18:-0.19215 19:0.482352 20:0.058823 21:-0.68627 22:-0.81960 23:0.121568 24:0.709803 25:-0.99215 26:0.474509 27:0.905882 28:-0.06666 29:-0.85098 30:-0.90588 31:-0.67843 32:-0.13725 33:-0.56078 34:-0.36470 35:0.827450 36:0.396078 37:0.388235 38:0.803921 39:0.647058 40:-0.50588 41:-0.84313 42:0.466666 43:-0.18431 44:0.207843 45:0.694117 46:0.427450 47:0.827450 48:0.890196 49:0.129411 50:-0.02745 51:-0.20784 52:0.090196 53:-0.85882 54:-0.00392 55:0.105882 56:-0.79607 57:-0.20784 58:0.874509 59:0.380392 60:-0.55294 61:0.098039 62:0.113725 63:0.121568 64:-0.10588 65:0.301960 66:0.866666 67:0.654901 68:0.929411 69:-0.45098 70:0.984313 71:-0.23921 72:0.027450 73:0.223529 74:-0.74117 75:0.709803 76:-0.63137 77:0.309803 78:-0.13725 79:-0.31764 80:1.0 81:0.647058 82:-1.0 83:-0.03529 84:-0.82745 85:0.937254 86:-0.74901 87:0.945098 88:-0.95294 89:0.168627 90:0.121568 91:0.184313 92:-0.94509 93:0.756862 94:0.631372 95:-0.32549 96:-0.27843 97:0.082352 98:-0.12156 99:0.372549 100:0.725490 101:0.176470 102:-0.63921 103:-0.38039 104:-0.23137 105:0.341176 106:0.050980 107:-0.75686 108:-0.72549 109:0.286274 110:-0.96862 111:0.850980 112:0.223529 113:-0.93725 114:-0.79607 115:0.027450 116:-0.05098 117:-0.51372 118:0.560784 119:0.490196 120:0.223529 121:0.764705 122:-0.90588 123:-0.37254 124:0.403921 125:0.521568 126:-0.33333 127:-0.65490 128:-0.63137 129:-0.00392 130:0.552941 131:0.419607 132:0.733333 133:-0.78823 134:-0.71764 135:-0.52156 136:-0.38823 137:-0.21568 138:-0.69411 139:0.850980 140:-0.73333 141:0.223529 142:-0.54509 143:0.184313 144:0.403921 145:0.207843 146:-0.14509 147:0.764705 148:0.466666 149:-0.51372 150:0.709803 151:0.992156 152:-0.05882 153:-0.49803 154:0.647058 155:-0.59215 156:-0.98431 157:0.890196 158:-0.74117 159:0.662745 160:0.764705 161:0.192156 162:-0.6 163:0.027450 164:0.294117 165:0.788235 166:-0.39607 167:-0.12941 168:-0.94509 169:0.278431 170:-0.60784 171:0.968627 172:0.199999 173:-0.92941 174:0.984313 175:-0.50588 176:0.262745 177:0.247058 178:-0.46666 179:-0.20784 180:0.199999 181:-0.85882 182:-0.48235 183:0.450980 184:-0.79607 185:-0.41176 186:0.992156 187:-0.92941 188:-0.30980 189:-0.18431 190:0.811764 191:-0.88235 192:0.254901 193:-0.62352 194:0.082352 195:0.019607 196:0.349019 197:0.458823 198:-0.01960 199:-0.85098 200:-0.15294 201:0.160784 202:-0.36470 203:-0.16078 204:0.035294 205:-0.63921 206:0.898039 207:-0.60784 208:0.788235 209:0.803921 210:-0.28627 211:0.662745 212:-0.06666 213:-0.74117 214:-0.81960 215:0.992156 216:-0.89803 217:0.035294 218:-0.43529 219:-0.69411 220:-0.55294 221:0.450980 222:-0.59215 223:0.890196 224:0.780392 225:0.137254 226:0.545098 227:0.890196 228:0.074509 229:0.419607 230:0.372549 231:0.623529 232:-0.81960 233:0.568627 234:0.215686 235:0.552941 236:-0.43529 237:0.356862 238:-0.12941 239:0.929411 240:0.819607 241:-0.37254 242:-0.81960 243:0.976470 244:-0.09019 245:0.145098 246:-0.12941 247:0.780392 248:0.129411 249:-0.83529 250:-0.60784 251:-0.57647 252:-0.20784 253:0.937254 254:0.662745 255:-0.41960 256:-1.0
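Each row above is one libsvm-encoded USPS sample: a 1-based digit label (1-10) followed by 256 index:value pixel entries in [-1, 1] that form a 16x16 image. As a minimal sketch of how such a row can be decoded into the uint8 image the operator emits (the helper name decode_usps_line is ours, for illustration only):

import numpy as np

def decode_usps_line(line):
    # split "<label> 1:<v1> ... 256:<v256>" into label and pixel values
    fields = line.split()
    label = int(fields[0]) - 1  # shift the 1-based libsvm label to 0..9
    values = np.asarray([float(f.split(':')[-1]) for f in fields[1:]], dtype=np.float32)
    # rescale [-1, 1] floats to [0, 255] and shape into height x width x channel
    image = ((values.reshape(16, 16, 1) + 1) / 2 * 255).astype(np.uint8)
    return label, image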

@@ -0,0 +1,308 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test USPS dataset operators
"""
import os
from typing import cast

import matplotlib.pyplot as plt
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testUSPSDataset"
WRONG_DIR = "../data/dataset/testMnistData"

def load_usps(path, usage):
    """
    Load USPS data directly from the libsvm-format files ("usps" for train, "usps.t" for test)
    """
    assert usage in ["train", "test"]
    if usage == "train":
        data_path = os.path.realpath(os.path.join(path, "usps"))
    elif usage == "test":
        data_path = os.path.realpath(os.path.join(path, "usps.t"))

    with open(data_path, 'r') as f:
        raw_data = [line.split() for line in f.readlines()]
        # each line is "<label> 1:<v1> 2:<v2> ... 256:<v256>"; keep only the values
        tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
        images = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16, 1))
        # rescale pixel values from [-1, 1] to [0, 255]
        images = ((cast(np.ndarray, images) + 1) / 2 * 255).astype(dtype=np.uint8)
        # libsvm labels are 1-based (1..10); shift them to 0..9
        labels = [int(d[0]) - 1 for d in raw_data]
        return images, labels

def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
        plt.title(labels[i])
    plt.show()

def test_usps_content_check():
    """
    Validate USPSDataset image readings against a direct parse of the data files
    """
    logger.info("Test USPSDataset Op with content check")
    # num_samples=10 exceeds the 3 available rows, so all rows are read
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=10, shuffle=False)
    images, labels = load_usps(DATA_DIR, "train")
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for m in range(16):
            for n in range(16):
                # pixels must not be inverted, and may differ by at most 1 due to rounding
                assert (data["image"][m, n, 0] != 0 or images[i][m, n, 0] != 255) and \
                       (data["image"][m, n, 0] != 255 or images[i][m, n, 0] != 0)
                assert (data["image"][m, n, 0] == images[i][m, n, 0]) or \
                       (data["image"][m, n, 0] == images[i][m, n, 0] + 1) or \
                       (data["image"][m, n, 0] + 1 == images[i][m, n, 0])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 3

    test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False)
    images, labels = load_usps(DATA_DIR, "test")
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for m in range(16):
            for n in range(16):
                if (data["image"][m, n, 0] == 0 and images[i][m, n, 0] == 255) or \
                        (data["image"][m, n, 0] == 255 and images[i][m, n, 0] == 0):
                    assert False
                if (data["image"][m, n, 0] != images[i][m, n, 0]) and \
                        (data["image"][m, n, 0] != images[i][m, n, 0] + 1) and \
                        (data["image"][m, n, 0] + 1 != images[i][m, n, 0]):
                    assert False
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 3

def test_usps_basic():
    """
    Validate basic usage of USPSDataset
    """
    logger.info("Test USPSDataset Op")

    # case 1: test loading whole dataset
    train_data = ds.USPSDataset(DATA_DIR, "train")
    num_iter = 0
    for _ in train_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    test_data = ds.USPSDataset(DATA_DIR, "test")
    num_iter = 0
    for _ in test_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    # case 2: test num_samples
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2)
    num_iter = 0
    for _ in train_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 2

    # case 3: test repeat
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2)
    train_data = train_data.repeat(5)
    num_iter = 0
    for _ in train_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 10

    # case 4: test batch with drop_remainder=False
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
    assert train_data.get_dataset_size() == 3
    assert train_data.get_batch_size() == 1
    train_data = train_data.batch(batch_size=2)  # drop_remainder defaults to False
    assert train_data.get_batch_size() == 2
    assert train_data.get_dataset_size() == 2

    num_iter = 0
    for _ in train_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 2

    # case 5: test batch with drop_remainder=True
    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
    assert train_data.get_dataset_size() == 3
    assert train_data.get_batch_size() == 1
    train_data = train_data.batch(batch_size=2, drop_remainder=True)  # the incomplete remainder batch is dropped
    assert train_data.get_dataset_size() == 1
    assert train_data.get_batch_size() == 2
    num_iter = 0
    for _ in train_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 1
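Cases 4 and 5 pin down the batch arithmetic on the 3-sample split: with drop_remainder left False the op yields ceil(3/2) = 2 batches (the last holding a single sample), while drop_remainder=True discards the short batch, leaving floor(3/2) = 1. A quick standalone check of that arithmetic:

import math
assert math.ceil(3 / 2) == 2   # drop_remainder=False keeps the short final batch
assert math.floor(3 / 2) == 1  # drop_remainder=True drops it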

def test_usps_exception():
    """
    Test error cases for USPSDataset
    """
    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.USPSDataset(DATA_DIR, "train", num_shards=10)
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.USPSDataset(DATA_DIR, "test", num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.USPSDataset(DATA_DIR, "train", shard_id=0)
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.USPSDataset(DATA_DIR, "test", shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id="0")
    with pytest.raises(TypeError, match=error_msg_7):
        ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id="0")

    error_msg_8 = "invalid input shape"
    with pytest.raises(RuntimeError, match=error_msg_8):
        train_data = ds.USPSDataset(DATA_DIR, "train")
        # USPS images are already decoded, so applying Decode again should fail
        train_data = train_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        for _ in train_data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        test_data = ds.USPSDataset(DATA_DIR, "test")
        test_data = test_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        for _ in test_data.__iter__():
            pass

    error_msg_9 = "failed to find USPS train data file"
    with pytest.raises(RuntimeError, match=error_msg_9):
        train_data = ds.USPSDataset(WRONG_DIR, "train")
        for _ in train_data.__iter__():
            pass
    error_msg_10 = "failed to find USPS test data file"
    with pytest.raises(RuntimeError, match=error_msg_10):
        test_data = ds.USPSDataset(WRONG_DIR, "test")
        for _ in test_data.__iter__():
            pass

def test_usps_visualize(plot=False):
    """
    Visualize USPSDataset results
    """
    logger.info("Test USPSDataset visualization")

    train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (16, 16, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 3
    if plot:
        visualize_dataset(image_list, label_list)

    test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (16, 16, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 3
    if plot:
        visualize_dataset(image_list, label_list)

def test_usps_usage():
    """
    Validate the usage flag of USPSDataset
    """
    logger.info("Test USPSDataset usage flag")

    def test_config(usage, path=None):
        path = DATA_DIR if path is None else path
        try:
            data = ds.USPSDataset(path, usage=usage, shuffle=False)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    assert test_config("train") == 3
    assert test_config("test") == 3

    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

    # change this directory to the folder that contains all USPS files
    all_files_path = None
    # the following tests run on the entire dataset
    if all_files_path is not None:
        assert test_config("train", all_files_path) == 3
        assert test_config("test", all_files_path) == 3
        assert ds.USPSDataset(all_files_path, usage="train").get_dataset_size() == 3
        assert ds.USPSDataset(all_files_path, usage="test").get_dataset_size() == 3


if __name__ == '__main__':
    test_usps_content_check()
    test_usps_basic()
    test_usps_exception()
    test_usps_visualize(plot=True)
    test_usps_usage()