From feff8899ac55493926082b3a8c1f09a93491e8f0 Mon Sep 17 00:00:00 2001 From: liyong Date: Wed, 27 May 2020 10:50:29 +0800 Subject: [PATCH] support padding samples --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 28 +- .../ccsrc/dataset/api/python_bindings.cc | 24 +- .../engine/datasetops/source/mindrecord_op.cc | 132 ++++-- .../engine/datasetops/source/mindrecord_op.h | 27 +- .../ccsrc/dataset/engine/gnn/graph_loader.cc | 6 +- .../mindrecord/include/common/shard_utils.h | 4 + .../ccsrc/mindrecord/include/shard_column.h | 6 +- .../include/shard_distributed_sample.h | 47 ++ .../ccsrc/mindrecord/include/shard_reader.h | 13 +- .../ccsrc/mindrecord/include/shard_sample.h | 10 +- .../ccsrc/mindrecord/include/shard_task.h | 11 +- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 72 ++- .../ccsrc/mindrecord/meta/shard_category.cc | 2 +- .../ccsrc/mindrecord/meta/shard_column.cc | 19 + .../meta/shard_distributed_sample.cc | 64 +++ .../ccsrc/mindrecord/meta/shard_sample.cc | 21 +- mindspore/ccsrc/mindrecord/meta/shard_task.cc | 20 +- mindspore/dataset/engine/datasets.py | 38 +- mindspore/dataset/engine/validators.py | 29 +- .../cpp/mindrecord/ut_shard_operator_test.cc | 3 - tests/ut/python/dataset/test_minddataset.py | 20 +- .../python/dataset/test_minddataset_padded.py | 444 ++++++++++++++++++ 22 files changed, 893 insertions(+), 147 deletions(-) create mode 100644 mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h create mode 100644 mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc create mode 100644 tests/ut/python/dataset/test_minddataset_padded.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 9b87f044f59..ede3f1f8d18 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -32,6 +32,7 @@ #include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/engine/datasetops/filter_op.h" #include "mindrecord/include/shard_category.h" +#include "mindrecord/include/shard_distributed_sample.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" #include "dataset/util/random.h" @@ -400,7 +401,7 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto RETURN_STATUS_UNEXPECTED(err_msg); } - constexpr int kMaxPartitions = 64; + constexpr int kMaxPartitions = 1024; if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) { std::string err_msg = "Error: partitions is invalid or not set."; RETURN_STATUS_UNEXPECTED(err_msg); @@ -438,6 +439,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptrSetColumnsToLoad(in_col_names); } + if (!args["padded_sample"].is_none()) { + (void)builder->SetPaddedSample(args["padded_sample"]); + (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); + } std::vector> operators; for (auto arg : args) { std::string key = py::str(arg.first); @@ -447,14 +452,15 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptrSetNumMindRecordWorkers(ToInt(value)); } else if (key == "block_reader" && ToBool(value) == true) { (void)builder->SetBlockReader(); - } else if (key == "global_shuffle" && ToBool(value) == true) { - uint32_t seed = args["partitions"].is_none() ? 
GetSeed() : 0; + } else if (key == "shuffle_option" && ToBool(value) == true) { + if (!args["partitions"].is_none()) continue; + uint32_t seed = GetSeed(); operators.push_back(std::make_shared(seed)); } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("_create_for_minddataset"); - std::shared_ptr sample_op = - create().cast>(); - operators.push_back(sample_op); + auto sampler = py::reinterpret_borrow(value); + auto create = sampler.attr("_create_for_minddataset"); + auto op = create().cast>(); + operators.push_back(op); } } } @@ -465,7 +471,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr(1, in_partitions[0], in_partitions[1])); + auto shuffle = ToBool(args["shuffle_option"]); + int num_padded = 0; + if (!args["num_padded"].is_none()) { + num_padded = ToInt(args["num_padded"]); + } + operators.push_back( + std::make_shared(in_partitions[0], in_partitions[1], num_padded, shuffle, 0)); } if (!operators.empty()) { diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 98bc59a7e6b..b1734eaa2be 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -66,6 +66,7 @@ #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" #include "mindrecord/include/shard_pk_sample.h" +#include "mindrecord/include/shard_distributed_sample.h" #include "mindrecord/include/shard_sample.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -157,17 +158,17 @@ void bindDatasetOps(py::module *m) { }); (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", - [](const std::vector &paths, bool load_dataset, const py::object &sampler) { - int64_t count = 0; - std::shared_ptr op; - if (py::hasattr(sampler, "_create_for_minddataset")) { - auto create = sampler.attr("_create_for_minddataset"); - op = create().cast>(); - } - THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count)); - return count; - }); + .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, + const int64_t num_padded) { + int64_t count = 0; + std::shared_ptr op; + if (py::hasattr(sampler, "_create_for_minddataset")) { + auto create = sampler.attr("_create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); + return count; + }); (void)py::class_>(*m, "ManifestOp") .def_static("get_num_rows_and_classes", @@ -472,6 +473,7 @@ void bindSamplerOps(py::module *m) { (void)py::class_>( *m, "MindrecordSubsetRandomSampler") .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + (void)py::class_>( *m, "MindrecordPkSampler") .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 52899873382..c89d6cba3dc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -53,6 +53,8 @@ MindRecordOp::Builder::Builder() : build_dataset_file_({}) { build_op_connector_queue_size_ = cfg->op_connector_size(); build_block_reader_ = false; builder_num_workers_ = 0; + build_num_padded_ = 0; + build_sample_ = nullptr; } // The builder "build" method creates the final object. 
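The hunk below introduces Builder::ToJson(), which splits the Python padded sample into a JSON part (raw fields) and a separate bytes map (blob fields). A minimal sketch of the dict shape this conversion assumes, mirroring the tests added later in this patch (field names are illustrative only):

    # Raw fields (str/int/float) are converted into sample_json_;
    # bytes fields are kept aside in sample_bytes_ for the blob columns.
    padded_sample = {
        "file_name": "dummy.jpg",      # raw field -> sample_json_
        "label": -1,                   # raw field -> sample_json_
        "data": b"\xff\xd8\xff\xe0",   # bytes field -> sample_bytes_
    }
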
@@ -63,24 +65,57 @@ Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Building a MindRecordOp that has not provided a file."); } - + mindrecord::json sample_json; + if (build_num_padded_ > 0) { + sample_json = ToJson(build_sample_); + } new_mind_record_op = std::make_shared( build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, - build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_); + build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, + sample_json, build_sample_bytes_); RETURN_IF_NOT_OK(new_mind_record_op->Init()); - *ptr = std::move(new_mind_record_op); return Status::OK(); } Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } +mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { + if (obj.is_none()) { + return nullptr; + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { // also catch py::bytes + return obj.cast(); + } + if (py::isinstance(obj)) { + auto out = mindrecord::json::object(); + for (const py::handle &key : obj) { + if (py::isinstance(obj[key])) { + build_sample_bytes_[py::str(key).cast()] = obj[key].cast(); + } else { + out[py::str(key).cast()] = ToJson(obj[key]); + } + } + return out; + } + MS_LOG(ERROR) << "Python object convert to json failed, object is: " << py::cast(obj); + return mindrecord::json(); +} + // Constructor of the MindRecordOp. MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader) + const std::vector> &operators, const bool &block_reader, + int64_t num_padded, const mindrecord::json &sample_json, + const std::map &sample_bytes) : ParallelOp(num_mind_record_workers, op_connector_queue_size), rows_per_buffer_(rows_per_buffer), dataset_file_(dataset_file), @@ -92,7 +127,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf buffers_needed_(0), buf_cnt_(0), ended_worker_(0), - buffer_water_mark_(0) { + buffer_water_mark_(0), + num_padded_(num_padded), + sample_json_(sample_json), + sample_bytes_(sample_bytes) { io_blk_queues_.Init(num_workers_, op_connector_queue_size); if (!block_reader_) return; for (int32_t i = 0; i < num_workers_; ++i) { @@ -104,7 +142,7 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf Status MindRecordOp::Init() { shard_reader_ = std::make_unique(); auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, - block_reader_); + block_reader_, num_padded_); CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed. 
Error message: " + ErrnoToMessage(rc)); @@ -161,10 +199,6 @@ Status MindRecordOp::Init() { column_name_id_map_[columns_to_load_[i]] = i; } - num_rows_ = shard_reader_->GetNumRows(); - // Compute how many buffers we would need to accomplish rowsPerBuffer - buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; - return Status::OK(); } @@ -261,20 +295,30 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_bu std::unique_ptr tensor_table = std::make_unique(); for (int32_t i = 0; i < rows_per_buffer_; ++i) { ShardTuple tupled_buffer; + mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; if (block_reader_) { if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); } else { int32_t row_id = buffer_id * rows_per_buffer_ + i; - tupled_buffer = shard_reader_->GetNextById(row_id, worker_id); + auto rc = shard_reader_->GetNextById(row_id, worker_id); + task_type = rc.first; + tupled_buffer = rc.second; + if (task_type == mindrecord::TaskType::kPaddedTask) { + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); + tensor_table->push_back(std::move(tensor_row)); + } if (tupled_buffer.empty()) break; } - for (const auto &tupled_row : tupled_buffer) { - std::vector columns_blob = std::get<0>(tupled_row); - mindrecord::json columns_json = std::get<1>(tupled_row); - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json)); - tensor_table->push_back(std::move(tensor_row)); + if (task_type == mindrecord::TaskType::kCommonTask) { + for (const auto &tupled_row : tupled_buffer) { + std::vector columns_blob = std::get<0>(tupled_row); + mindrecord::json columns_json = std::get<1>(tupled_row); + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json, task_type)); + tensor_table->push_back(std::move(tensor_row)); + } } } @@ -284,7 +328,7 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_bu } Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json) { + const mindrecord::json &columns_json, const mindrecord::TaskType task_type) { for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { auto column_name = columns_to_load_[i_col]; @@ -297,11 +341,39 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector column_shape; // Get column data - auto has_column = shard_reader_->GetShardColumn()->GetColumnValueByName( - column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size, - &column_shape); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); + auto shard_column = shard_reader_->GetShardColumn(); + if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { + auto rc = + shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); + if (rc.first != MSRStatus::SUCCESS) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data type."); + } + if (rc.second == mindrecord::ColumnInRaw) { + auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve raw data from padding sample."); + } + } else if (rc.second == mindrecord::ColumnInBlob) { + if 
(sample_bytes_.find(column_name) == sample_bytes_.end()) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve blob data from padding sample."); + } + std::string ss(sample_bytes_[column_name]); + n_bytes = ss.size(); + data_ptr = std::make_unique(n_bytes); + std::copy(ss.begin(), ss.end(), data_ptr.get()); + } else { + RETURN_STATUS_UNEXPECTED("Retrieved data type is unknown."); + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + } else { + auto has_column = + shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, + &column_data_type, &column_data_type_size, &column_shape); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); + } } std::shared_ptr tensor; @@ -334,7 +406,8 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { } for (int32_t i = 0; i < rows_per_buffer_; i++) { // Block reader does NOT care about argument - ShardTuple tuple_buffer = shard_reader_->GetNextById(i, i); + auto rc = shard_reader_->GetNextById(i, i); + ShardTuple tuple_buffer = rc.second; if (tuple_buffer.empty()) break; block_buffer_[buffer_id % num_workers_]->push_back(std::move(tuple_buffer)); } @@ -348,11 +421,8 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { Status MindRecordOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); num_rows_ = shard_reader_->GetNumRows(); - - buffers_needed_ = num_rows_ / rows_per_buffer_; - if (num_rows_ % rows_per_buffer_ != 0) { - buffers_needed_++; - } + // Compute how many buffers we would need to accomplish rowsPerBuffer + buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; while (true) { // each iterator is 1 epoch for (int32_t i = 0; i < buffers_needed_; ++i) { @@ -417,9 +487,9 @@ Status MindRecordOp::LaunchThreadAndInitOp() { } Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count) { + const std::shared_ptr &op, int64_t *count, int64_t num_padded) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); if (rc == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index 251b4f91302..77b25139303 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -104,10 +104,22 @@ class MindRecordOp : public ParallelOp { return *this; } + Builder &SetNumToPadSamples(int64_t num_padded) { + build_num_padded_ = num_padded; + return *this; + } + + Builder &SetPaddedSample(const py::handle &sample) { + build_sample_ = sample; + return *this; + } + Status SanityCheck() const; static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } + mindrecord::json ToJson(const py::handle &obj); + private: static constexpr int32_t kDefaultMindRecordWorkers = 4; // The builder saves all MindRecordOp construction arguments internally. 
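These builder setters are driven from the Python layer. A minimal usage sketch of the user-facing API this patch enables, with the file name and field values taken from the new tests and therefore illustrative only:

    import mindspore.dataset as ds

    # padded_sample must carry the same keys as columns_list; it is appended
    # num_padded times so that dataset size plus num_padded is divisible by num_shards.
    padded_sample = {"file_name": "dummy.jpg", "label": -1, "data": b"\x00"}
    data_set = ds.MindDataset("imagenet.mindrecord0", ["data", "file_name", "label"],
                              num_shards=4, shard_id=0,
                              padded_sample=padded_sample, num_padded=2)
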
@@ -121,6 +133,9 @@ class MindRecordOp : public ParallelOp { std::vector build_columns_to_load_; std::vector> build_operators_; bool build_block_reader_; + int64_t build_num_padded_; + py::handle build_sample_; + std::map build_sample_bytes_; }; // Constructor of the MindRecordOp. @@ -133,7 +148,9 @@ class MindRecordOp : public ParallelOp { // @param operators - ShardOperators for Shuffle, Category, Sample MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader); + const std::vector> &operators, const bool &block_reader, + int64_t num_padded_, const mindrecord::json &sample_json, + const std::map &sample_bytes_); // Destructor ~MindRecordOp() override; @@ -178,7 +195,7 @@ class MindRecordOp : public ParallelOp { int32_t num_rows() const { return num_rows_; } static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count); + const std::shared_ptr &op, int64_t *count, int64_t num_padded); // Getter method int32_t rows_per_buffer() const { return rows_per_buffer_; } @@ -209,7 +226,7 @@ class MindRecordOp : public ParallelOp { // @param columns_blob - the blob data received from the reader // @param columns_json - the data for fields received from the reader Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json); + const mindrecord::json &columns_json, const mindrecord::TaskType task_type); Status FetchBlockBuffer(const int32_t &buffer_id); @@ -226,6 +243,10 @@ class MindRecordOp : public ParallelOp { std::atomic ended_worker_; std::atomic buffer_water_mark_; + int64_t num_padded_; + mindrecord::json sample_json_; + std::map sample_bytes_; + std::unique_ptr data_schema_; // Data schema for column typing std::vector columns_blob_; // Blob Columns to load from dataset std::vector columns_blob_index_; // Blob Columns to load from dataset diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc index c517fda969e..9e5cbbb7889 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc @@ -203,7 +203,8 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vectorPost(); - ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id); + auto ret = shard_reader_->GetNextById(row_id_++, worker_id); + ShardTuple rows = ret.second; while (rows.empty() == false) { RETURN_IF_INTERRUPTED(); for (const auto &tupled_row : rows) { @@ -224,7 +225,8 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; } } - rows = shard_reader_->GetNextById(row_id_++, worker_id); + auto rc = shard_reader_->GetNextById(row_id_++, worker_id); + rows = rc.second; } return Status::OK(); } diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h index 65a8d53e72c..8aa5bdfbda4 100644 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h @@ -73,6 +73,10 @@ enum ShardType { kCV = 1, }; +enum TaskType { + kCommonTask = 0, + kPaddedTask = 1, +}; enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; enum ShuffleType { 
kShuffleCategory, kShuffleSample }; diff --git a/mindspore/ccsrc/mindrecord/include/shard_column.h b/mindspore/ccsrc/mindrecord/include/shard_column.h index 496e7ec3ea3..ec71fd5bd37 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_column.h +++ b/mindspore/ccsrc/mindrecord/include/shard_column.h @@ -89,12 +89,16 @@ class ShardColumn { MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, const unsigned char **data, std::unique_ptr *data_ptr, uint64_t *n_bytes); + std::pair GetColumnTypeByName(const std::string &column_name, + ColumnDataType *column_data_type, + uint64_t *column_data_type_size, + std::vector *column_shape); - private: /// \brief get column value from json MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, std::unique_ptr *data_ptr, uint64_t *n_bytes); + private: /// \brief get float value from json template MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); diff --git a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h new file mode 100644 index 00000000000..c962c869d0a --- /dev/null +++ b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ + +#include +#include +#include +#include +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_shuffle.h" +#include "mindrecord/include/shard_sample.h" + +namespace mindspore { +namespace mindrecord { +class ShardDistributedSample : public ShardSample { + public: + ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); + + ~ShardDistributedSample() override{}; + + MSRStatus PreExecute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + bool shuffle_; + int no_of_padded_samples_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 8db7761fb85..9be017c6467 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -58,7 +58,8 @@ using ROW_GROUPS = std::tuple>>, std::vector>>; using ROW_GROUP_BRIEF = std::tuple>, std::vector>; -using TASK_RETURN_CONTENT = std::pair, json>>>; +using TASK_RETURN_CONTENT = + std::pair, json>>>>; const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode const int kNumPageInBuffer = 16; // page buffer size in block-reader mode @@ -78,7 +79,8 @@ class ShardReader { /// \return MSRStatus the status of MSRStatus MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, const std::vector &selected_columns = {}, - const std::vector> &operators = {}, const bool &block_reader = false); + const std::vector> &operators = {}, const bool &block_reader = false, + const int num_padded = 0); /// \brief open files and initialize reader, python API /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list @@ -127,7 +129,7 @@ class ShardReader { /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &op, int64_t *count); + const std::shared_ptr &op, int64_t *count, const int num_padded); /// \brief shuffle task with incremental seed /// \return void @@ -182,7 +184,8 @@ class ShardReader { /// \brief return a row by id /// \return a batch of images and image data - std::vector, json>> GetNextById(const int64_t &task_id, const int32_t &consumer_id); + std::pair, json>>> GetNextById(const int64_t &task_id, + const int32_t &consumer_id); /// \brief return a batch in block-reader mode, given that one is ready /// \return a batch of images and image data @@ -330,6 +333,8 @@ class ShardReader { bool all_in_index_ = true; // if all columns are stored in index-table bool interrupt_ = false; // reader interrupted + int num_padded_; // number of padding samples + // Delivery/Iterator mode begin const std::string kThreadName = "THRD_ITER_"; // prefix of thread name std::vector thread_set_; // thread list diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index 7905f328f96..111df3bc1aa 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -38,22 +38,22 @@ class ShardSample : public ShardOperator { ~ShardSample() override{}; - const std::pair GetPartitions() 
const; - MSRStatus Execute(ShardTask &tasks) override; MSRStatus SufExecute(ShardTask &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - private: + protected: int numerator_; int denominator_; - int no_of_samples_; int partition_id_; + std::shared_ptr shuffle_op_; + + private: + int no_of_samples_; std::vector indices_; SamplerType sampler_type_; - std::shared_ptr shuffle_op_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h index d48c25c9cd6..9b8ac54a467 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/mindrecord/include/shard_task.h @@ -29,9 +29,10 @@ class ShardTask { public: void MakePerm(); - void InsertTask(int shard_id, int group_id, const std::vector &offset, const json &label); + void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label); - void InsertTask(std::tuple, std::vector, json> task); + void InsertTask(std::tuple, std::vector, json> task); void PopBack(); @@ -39,15 +40,15 @@ class ShardTask { uint32_t SizeOfRows() const; - std::tuple, std::vector, json> &GetTaskByID(size_t id); + std::tuple, std::vector, json> &GetTaskByID(size_t id); - std::tuple, std::vector, json> &GetRandomTask(); + std::tuple, std::vector, json> &GetRandomTask(); static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); uint32_t categories = 1; - std::vector, std::vector, json>> task_list_; + std::vector, std::vector, json>> task_list_; std::vector permutation_; }; } // namespace mindrecord diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index fcb588fff83..e5b18c1f9c3 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -45,6 +45,7 @@ ShardReader::ShardReader() { row_id_ = 0; num_blocks_ = 0; block_reader_ = false; + num_padded_ = 0; } std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { @@ -790,7 +791,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { } MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &op, int64_t *count) { + const std::shared_ptr &op, int64_t *count, const int num_padded) { if (SUCCESS != Init(file_paths, load_dataset)) { return FAILED; } @@ -802,11 +803,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths num_samples = category_op->GetNumSamples(num_rows_, num_classes); } else if (std::dynamic_pointer_cast(op)) { num_samples = op->GetNumSamples(num_rows_, 0); + if (-1 == num_samples) { + MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; + return FAILED; + } } else { - } - if (-1 == num_samples) { - MS_LOG(ERROR) << "Failed to get dataset size."; - return FAILED; + if (num_padded > 0) num_samples += num_padded; } *count = num_samples; return SUCCESS; @@ -814,7 +816,8 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, const std::vector &selected_columns, - const std::vector> &operators, const bool &block_reader) { + const std::vector> &operators, const bool &block_reader, + int num_padded) { // Open file and set header by ShardReader auto ret = Init(file_paths, load_dataset); if (SUCCESS 
!= ret) { @@ -844,6 +847,7 @@ MSRStatus ShardReader::Open(const std::vector &file_paths, bool loa // Initialize argument shard_count_ = static_cast(file_paths_.size()); n_consumer_ = n_consumer; + num_padded_ = num_padded; operators_ = operators; @@ -935,7 +939,7 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector(rg); auto group_id = std::get<1>(rg); auto n_Rows = std::get<3>(rg); - tasks_.InsertTask(shard_id, group_id, std::vector{n_Rows}, json{}); + tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); } return SUCCESS; } @@ -986,7 +990,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector(details)[iStart], + categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], std::get<5>(details)[iStart]); category_index++; } @@ -1014,7 +1018,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, local_columns[shard_id][i]); } @@ -1044,6 +1048,11 @@ MSRStatus ShardReader::CreateTasks(const std::vector 0) { + for (int i = 0; i < num_padded_; ++i) { + tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); + } + } } else { if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { return FAILED; @@ -1070,18 +1079,27 @@ MSRStatus ShardReader::CreateTasks(const std::vector= static_cast(tasks_.Size())) { - return std::make_pair(FAILED, std::vector, json>>()); + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); } // Pick up task from task list auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - auto shard_id = std::get<0>(std::get<0>(task)); - auto group_id = std::get<1>(std::get<0>(task)); - auto addr = std::get<1>(task); + // check task type + auto task_type = std::get<0>(task); + if (task_type == TaskType::kPaddedTask) { + return std::make_pair(SUCCESS, + std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); + } + + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); + auto addr = std::get<2>(task); const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); if (SUCCESS != ret.first) { - return std::make_pair(FAILED, std::vector, json>>()); + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); } const std::shared_ptr &page = ret.second; @@ -1093,7 +1111,8 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { MS_LOG(ERROR) << "File seekg failed"; file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, std::vector, json>>()); + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); } auto &io_read = @@ -1101,14 +1120,15 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ if (!io_read.good() || io_read.fail() || io_read.bad()) { MS_LOG(ERROR) << "File read failed"; file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, std::vector, json>>()); + return std::make_pair(FAILED, + std::pair(TaskType::kCommonTask, std::vector, json>>())); } // Deliver batch data to output map std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(std::get<2>(task))); + batch.emplace_back(std::move(images), std::move(std::get<3>(task))); - return std::make_pair(SUCCESS, std::move(batch)); + return 
std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); } MSRStatus ShardReader::ConsumerByRow(int consumer_id) { @@ -1133,7 +1153,7 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { if (SUCCESS != ret.first) { return FAILED; } - const auto &batch = ret.second; + const auto &batch = (ret.second).second; // Hanging if maximum map size exceeded // otherwise, set batch data in map { @@ -1193,8 +1213,8 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { // Pick up task from task list auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - auto shard_id = std::get<0>(std::get<0>(task)); - auto group_id = std::get<1>(std::get<0>(task)); + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); if (SUCCESS != std::get<0>(row_group_brief)) { return FAILED; @@ -1302,17 +1322,17 @@ std::vector, json>> ShardReader::GetNext() { return *res; } -std::vector, json>> ShardReader::GetNextById(const int64_t &task_id, - const int32_t &consumer_id) { +std::pair, json>>> ShardReader::GetNextById( + const int64_t &task_id, const int32_t &consumer_id) { if (interrupt_) { - return std::vector, json>>(); + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); } if (block_reader_) { - return GetBlockNext(); + return std::make_pair(TaskType::kCommonTask, GetBlockNext()); } const auto &ret = ConsumerOneTask(task_id, consumer_id); if (SUCCESS != ret.first) { - return std::vector, json>>(); + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); } return std::move(ret.second); } diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index dfca92a08c1..bd427a330a7 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -41,7 +41,7 @@ int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) { return std::min(num_categories_, num_classes) * num_elements_; } - return -1; + return 0; } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/mindrecord/meta/shard_column.cc index 86ad0c96d7b..8c3e8317888 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_column.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_column.cc @@ -66,6 +66,25 @@ ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool num_blob_column_ = blob_column_.size(); } +std::pair ShardColumn::GetColumnTypeByName(const std::string &column_name, + ColumnDataType *column_data_type, + uint64_t *column_data_type_size, + std::vector *column_shape) { + // Skip if column not found + auto column_category = CheckColumnName(column_name); + if (column_category == ColumnNotFound) { + return {FAILED, ColumnNotFound}; + } + + // Get data type and size + auto column_id = column_name_id_[column_name]; + *column_data_type = column_data_type_[column_id]; + *column_data_type_size = ColumnDataTypeSize[*column_data_type]; + *column_shape = column_shape_[column_id]; + + return {SUCCESS, column_category}; +} + MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, const json &columns_json, const unsigned char **data, std::unique_ptr *data_ptr, uint64_t *n_bytes, diff --git 
a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc new file mode 100644 index 00000000000..2b7a661c06e --- /dev/null +++ b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindrecord/include/shard_distributed_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, + uint32_t seed) + : ShardSample(1, num_shards, shard_id), shuffle_(shuffle), no_of_padded_samples_(no_of_padded_samples) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); +} + +int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (no_of_padded_samples_ <= 0) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } else { + auto padded_size = dataset_size + no_of_padded_samples_; + if (padded_size % denominator_ == 0) { + return padded_size / denominator_ * numerator_; + } else { + return -1; + } + } + return 0; +} +MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { + auto total_no = tasks.Size(); + if (no_of_padded_samples_ > 0) { + if (total_no % denominator_ != 0) { + MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; + return FAILED; + } + } + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index d7842a11a3b..c207747194a 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -25,32 +25,32 @@ namespace mindrecord { ShardSample::ShardSample(int n) : numerator_(0), denominator_(0), - no_of_samples_(n), partition_id_(0), + no_of_samples_(n), indices_({}), sampler_type_(kCustomTopNSampler) {} ShardSample::ShardSample(int num, int den) : numerator_(num), denominator_(den), - no_of_samples_(0), partition_id_(0), + no_of_samples_(0), indices_({}), sampler_type_(kCustomTopPercentSampler) {} ShardSample::ShardSample(int num, int den, int par) : numerator_(num), denominator_(den), - no_of_samples_(0), partition_id_(par), + no_of_samples_(0), indices_({}), sampler_type_(kCustomTopPercentSampler) {} ShardSample::ShardSample(const std::vector &indices, uint32_t seed) : numerator_(0), denominator_(0), - no_of_samples_(0), partition_id_(0), + no_of_samples_(0), indices_(indices), sampler_type_(kSubsetRandomSampler) { shuffle_op_ = std::make_shared(seed); @@ -71,19 +71,12 @@ 
int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (sampler_type_ == kSubsetRandomSampler) { return indices_.size(); } - return -1; -} - -const std::pair ShardSample::GetPartitions() const { - if (numerator_ == 1 && denominator_ > 1) { - return std::pair(denominator_, partition_id_); - } - return std::pair(-1, -1); + return 0; } MSRStatus ShardSample::Execute(ShardTask &tasks) { int no_of_categories = static_cast(tasks.categories); - int total_no = static_cast(tasks.Size()); + int total_no = static_cast(tasks.Size()); // make sure task_size int taking = 0; if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 @@ -97,7 +90,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { } else { // constructor TopPercent if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { if (numerator_ == 1 && denominator_ > 1) { // sharding - taking = (total_no / denominator_) + (total_no % denominator_ == 0 ? 0 : 1); + taking = (total_no + denominator_ - 1) / denominator_; } else { // non sharding taking = total_no * numerator_ / denominator_; taking -= (taking % no_of_categories); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc index 3abc725a7b3..0a8d8e3d43b 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc @@ -31,16 +31,18 @@ void ShardTask::MakePerm() { } } -void ShardTask::InsertTask(int shard_id, int group_id, const std::vector &offset, const json &label) { +void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label) { MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_.emplace_back(std::make_tuple(shard_id, group_id), offset, label); + task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); } -void ShardTask::InsertTask(std::tuple, std::vector, json> task) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<0>(task)) - << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() +void ShardTask::InsertTask(std::tuple, std::vector, json> task) { + MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) + << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() << ", size of task_list_: " << task_list_.size() << "."; + task_list_.push_back(std::move(task)); } @@ -52,19 +54,19 @@ uint32_t ShardTask::SizeOfRows() const { if (task_list_.size() == 0) return static_cast(0); // 1 task is 1 page - auto sum_num_rows = [](int x, std::tuple, std::vector, json> y) { - return x + std::get<1>(y)[0]; + auto sum_num_rows = [](int x, std::tuple, std::vector, json> y) { + return x + std::get<2>(y)[0]; }; uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); return nRows; } -std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { +std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { MS_ASSERT(id < task_list_.size()); return task_list_[id]; } -std::tuple, std::vector, json> &ShardTask::GetRandomTask() { +std::tuple, std::vector, json> &ShardTask::GetRandomTask() { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dis(0, task_list_.size() - 1); diff --git a/mindspore/dataset/engine/datasets.py 
b/mindspore/dataset/engine/datasets.py index 0afb6ce6b0a..a5d76fc4715 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2548,7 +2548,11 @@ class MindDataset(SourceDataset): sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, sampler is exclusive with shuffle and block_reader). Support list: SubsetRandomSampler, - PkSampler + PkSampler. + padded_sample (dict, optional): Samples will be appended to dataset, which + keys are the same as column_list. + num_padded (int, optional): Number of padding samples.Dataset size + plus num_padded should be divisible by num_shards. Raises: ValueError: If num_shards is specified but shard_id is None. @@ -2559,7 +2563,8 @@ class MindDataset(SourceDataset): @check_minddataset def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, - block_reader=False, sampler=None): + block_reader=False, sampler=None, padded_sample=None, + num_padded=None): super().__init__(num_parallel_workers) if isinstance(dataset_file, list): self.load_dataset = False @@ -2567,7 +2572,7 @@ class MindDataset(SourceDataset): self.load_dataset = True self.dataset_file = dataset_file self.columns_list = columns_list - self.global_shuffle = shuffle + self.shuffle_option = shuffle self.distribution = "" self.sampler = sampler @@ -2598,22 +2603,36 @@ class MindDataset(SourceDataset): raise ValueError("shuffle not allowed when use sampler") if block_reader is False and sampler is None: - self.global_shuffle = not bool(shuffle is False) + self.shuffle_option = not bool(shuffle is False) + + if num_padded is None: + num_padded = 0 self.num_shards = num_shards self.shard_id = shard_id self.block_reader = block_reader + self.padded_sample = padded_sample + self.num_padded = num_padded def get_args(self): args = super().get_args() + padded_sample = {} + if self.padded_sample: + for k, v in self.padded_sample.items(): + if isinstance(v, np.ndarray): + padded_sample[k] = v.tobytes() + else: + padded_sample[k] = v args["dataset_file"] = self.dataset_file args["load_dataset"] = self.load_dataset args["columns_list"] = self.columns_list - args["global_shuffle"] = self.global_shuffle + args["shuffle_option"] = self.shuffle_option args["partitions"] = self.partitions args["block_reader"] = self.block_reader args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["num_padded"] = self.num_padded + args["padded_sample"] = padded_sample args["sampler"] = self.sampler return args @@ -2628,19 +2647,22 @@ class MindDataset(SourceDataset): dataset_file = [self.dataset_file] else: dataset_file = self.dataset_file - num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler) + num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) if self.partitions is not None and self.partitions[0] > 0: if num_rows % self.partitions[0] == 0: num_rows = num_rows // self.partitions[0] else: + if self.num_padded > 0: + raise RuntimeError( + "Dataset size plus number of padded samples is not divisible by number of shards.") num_rows = num_rows // self.partitions[0] + 1 return num_rows def is_shuffled(self): - if self.global_shuffle is None: + if self.shuffle_option is None: return True - return self.global_shuffle or self.sampler.is_shuffled() + return self.shuffle_option or self.sampler.is_shuffled() def is_sharded(self): if self.num_shards is not None: diff --git 
a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4893aace361..1b01d738642 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -323,6 +323,27 @@ def check_sampler_shuffle_shard_options(param_dict): raise RuntimeError("shard_id is specified but num_shards is not.") +def check_padding_options(param_dict): + """ check for valid padded_sample and num_padded of padded samples""" + columns_list = param_dict.get('columns_list') + block_reader = param_dict.get('block_reader') + padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') + if padded_sample is not None: + if num_padded is None: + raise RuntimeError("padded_sample is specified and requires num_padded as well.") + if num_padded < 0: + raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) + if columns_list is None: + raise RuntimeError("padded_sample is specified and requires columns_list as well.") + for column in columns_list: + if column not in padded_sample: + raise ValueError("padded_sample cannot match columns_list.") + if block_reader: + raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") + + if padded_sample is None and num_padded is not None: + raise RuntimeError("num_padded is specified but padded_sample is not.") + def check_imagefolderdatasetv2(method): """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" @@ -549,9 +570,10 @@ def check_minddataset(method): def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) - nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id'] + nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'] nreq_param_list = ['columns_list'] nreq_param_bool = ['block_reader'] + nreq_param_dict = ['padded_sample'] # check dataset_file; required argument dataset_file = param_dict.get('dataset_file') @@ -569,12 +591,11 @@ def check_minddataset(method): check_param_type(nreq_param_bool, param_dict, bool) - num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - if (num_shards is not None and shard_id is None) or (num_shards is None and shard_id is not None): - raise ValueError("num_shards and shard_id need to be set or not set at the same time") + check_param_type(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) + check_padding_options(param_dict) return method(*args, **kwargs) return new_method diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 23c2c1e34fd..7fe60c3bfa6 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -139,9 +139,6 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { const int kPar = 2; std::vector> ops; ops.push_back(std::make_shared(kNum, kDen, kPar)); - auto partitions = std::dynamic_pointer_cast(ops[0])->GetPartitions(); - ASSERT_TRUE(partitions.first == 4); - ASSERT_TRUE(partitions.second == 2); ShardReader dataset; dataset.Open({file_name}, true, 4, column_list, ops); diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index 00af3fa660d..991bdf71a17 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -227,10 +227,9 @@ def 
test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): num_shards=num_shards, shard_id=partition_id) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info( - "-------------- partition : {} ------------------------".format(partition_id)) - logger.info( - "-------------- item[label]: {} -----------------------".format(item["label"])) + logger.info("-------------- partition : {} ------------------------".format(partition_id)) + logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) num_iter += 1 return num_iter @@ -321,12 +320,11 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file): """issue 888 test.""" columns_list = ["data", "label"] num_readers = 2 - data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, - num_readers, shuffle=False, num_shards=5, shard_id=1) - data = data.shuffle(2) - data = data.repeat(9) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1) + data_set = data_set.shuffle(2) + data_set = data_set.repeat(9) num_iter = 0 - for _ in data.create_dict_iterator(): + for _ in data_set.create_dict_iterator(): num_iter += 1 assert num_iter == 18 @@ -335,8 +333,7 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): """tutorial for cv minddataset.""" columns_list = ["data", "label"] num_readers = 4 - data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, - block_reader=True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, block_reader=True) assert data_set.get_dataset_size() == 10 repeat_num = 2 data_set = data_set.repeat(repeat_num) @@ -544,7 +541,6 @@ def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): num_iter += 1 assert num_iter == 10 - def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file): """tutorial for nlp minderdataset.""" num_readers = 4 diff --git a/tests/ut/python/dataset/test_minddataset_padded.py b/tests/ut/python/dataset/test_minddataset_padded.py new file mode 100644 index 00000000000..8128855b244 --- /dev/null +++ b/tests/ut/python/dataset/test_minddataset_padded.py @@ -0,0 +1,444 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +""" +This is the test module for mindrecord +""" +import collections +import json +import numpy as np +import os +import pytest +import re +import string + +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore import log as logger +from mindspore.dataset.transforms.vision import Inter +from mindspore.mindrecord import FileWriter + +FILES_NUM = 4 +CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" +CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" +CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" +CV_DIR_NAME = "../data/mindrecord/testImageNetData" +NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" +NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" +NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt" + + +@pytest.fixture +def add_and_remove_cv_file(): + """add/remove cv file""" + paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) + for x in range(FILES_NUM)] + for x in paths: + os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None + os.remove("{}.db".format(x)) if os.path.exists( + "{}.db".format(x)) else None + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_cv_data" + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + + +@pytest.fixture +def add_and_remove_nlp_file(): + """add/remove nlp file""" + paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) + for x in range(FILES_NUM)] + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(NLP_FILE_NAME, FILES_NUM) + data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)] + nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"}, + "rating": {"type": "float32"}, + "input_ids": {"type": "int64", + "shape": [-1]}, + "input_mask": {"type": "int64", + "shape": [1, -1]}, + "segment_ids": {"type": "int64", + "shape": [2, -1]} + } + writer.set_header_size(1 << 14) + writer.set_page_size(1 << 15) + writer.add_schema(nlp_schema_json, "nlp_schema") + writer.add_index(["id", "rating"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_nlp_data" + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + +def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["label", "file_name", "data"] + + data = get_data(CV_DIR_NAME) + padded_sample = data[0] + padded_sample['label'] = -1 + padded_sample['file_name'] = 'dummy.jpg' + num_readers = 4 + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, padded_sample=padded_sample, num_padded=5) + assert data_set.get_dataset_size() == 15 + num_iter = 0 + num_padded_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} 
----------------------------".format(item["label"])) + if item['label'] == -1: + num_padded_iter += 1 + assert item['file_name'] == bytes(padded_sample['file_name'], + encoding='utf8') + assert item['label'] == padded_sample['label'] + assert (item['data'] == np.array(list(padded_sample['data']))).all() + num_iter += 1 + assert num_padded_iter ==5 + assert num_iter == 15 + + +def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["data", "file_name", "label"] + + data = get_data(CV_DIR_NAME) + padded_sample = data[0] + padded_sample['label'] = -2 + padded_sample['file_name'] = 'dummy.jpg' + num_readers = 4 + + def partitions(num_shards, num_padded, dataset_size): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, + padded_sample=padded_sample, + num_padded=num_padded) + assert data_set.get_dataset_size() == dataset_size + num_iter = 0 + num_padded_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- partition : {} ------------------------".format(partition_id)) + logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) + if item['label'] == -2: + num_padded_iter += 1 + assert item['file_name'] == bytes(padded_sample['file_name'], encoding='utf8') + assert item['label'] == padded_sample['label'] + assert (item['data'] == np.array(list(padded_sample['data']))).all() + num_iter += 1 + return num_iter + + assert partitions(4, 2, 3) == 3 + assert partitions(5, 5, 3) == 3 + assert partitions(9, 8, 2) == 2 + +def test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["data", "file_name", "label"] + + data = get_data(CV_DIR_NAME) + padded_sample = data[0] + padded_sample['label'] = -2 + padded_sample['file_name'] = 'dummy.jpg' + num_readers = 4 + + def partitions(num_shards, num_padded): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, + padded_sample=padded_sample, + num_padded=num_padded) + num_iter = 0 + for item in data_set.create_dict_iterator(): + num_iter += 1 + return num_iter + + with pytest.raises(RuntimeError): + partitions(4, 1) + +def test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_and_remove_cv_file): + columns_list = ["data", "file_name", "label"] + + data = get_data(CV_DIR_NAME) + padded_sample = data[0] + padded_sample['label'] = -2 + padded_sample['file_name'] = 'dummy.jpg' + num_readers = 4 + + def partitions(num_shards, num_padded): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, + padded_sample=padded_sample, + num_padded=num_padded) + with pytest.raises(RuntimeError): + data_set.get_dataset_size() == 3 + partitions(4, 1) + +def test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_remove_cv_file): + columns_list = ["data", "file_name", "label"] + + data = 
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample.pop('label', None)
+    padded_sample['file_name'] = 'dummy.jpg'
+    num_readers = 4
+
+    def partitions(num_shards, num_padded):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample,
+                                      num_padded=num_padded)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
+    with pytest.raises(Exception, match="padded_sample cannot match columns_list."):
+        partitions(4, 2)
+
+
+def test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample['label'] = -2
+    padded_sample['file_name'] = 'dummy.jpg'
+    num_readers = 4
+
+    def partitions(num_shards, num_padded):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample,
+                                      num_padded=num_padded)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
+    with pytest.raises(Exception, match="padded_sample is specified and requires columns_list as well."):
+        partitions(4, 2)
+
+
+def test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample['file_name'] = 'dummy.jpg'
+    num_readers = 4
+
+    def partitions(num_shards, num_padded):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
+    with pytest.raises(Exception, match="padded_sample is specified and requires num_padded as well."):
+        partitions(4, 2)
+
+
+def test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample['file_name'] = 'dummy.jpg'
+    num_readers = 4
+
+    def partitions(num_shards, num_padded):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      num_padded=num_padded)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
+    with pytest.raises(Exception, match="num_padded is specified but padded_sample is not."):
+        partitions(4, 2)
+
+
+def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
+    columns_list = ["input_ids", "id", "rating"]
+
+    data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
+    padded_sample = data[0]
+    padded_sample['id'] = "-1"
+    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
+    padded_sample['rating'] = 1.0
+    num_readers = 4
+
+    def partitions(num_shards, num_padded, dataset_size):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample,
+                                      num_padded=num_padded)
+            assert data_set.get_dataset_size() == dataset_size
+            num_iter = 0
+            num_padded_iter = 0
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
+                logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
+                logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
+                    item["input_ids"], item["input_ids"].shape))
+                if item['id'] == '-1':
+                    num_padded_iter += 1
+                    assert item['id'] == padded_sample['id']
+                    assert (item['input_ids'] == padded_sample['input_ids']).all()
+                    assert item['rating'] == padded_sample['rating']
+                num_iter += 1
+        return num_iter
+
+    assert partitions(4, 6, 4) == 4
+    assert partitions(5, 5, 3) == 3
+    assert partitions(9, 8, 2) == 2
+
+
+def get_data(dir_name):
+    """
+    Return raw data of the ImageNet-style test dataset.
+
+    Args:
+        dir_name (str): directory containing the images folder and the annotation file.
+    """
+    if not os.path.isdir(dir_name):
+        raise IOError("Directory {} does not exist".format(dir_name))
+    img_dir = os.path.join(dir_name, "images")
+    ann_file = os.path.join(dir_name, "annotation.txt")
+    with open(ann_file, "r") as file_reader:
+        lines = file_reader.readlines()
+
+    data_list = []
+    for i, line in enumerate(lines):
+        try:
+            filename, label = line.split(",")
+            label = label.strip("\n")
+            with open(os.path.join(img_dir, filename), "rb") as file_reader:
+                img = file_reader.read()
+            data_json = {"id": i,
+                         "file_name": filename,
+                         "data": img,
+                         "label": int(label)}
+            data_list.append(data_json)
+        except FileNotFoundError:
+            continue
+    return data_list
+
+
+def get_nlp_data(dir_name, vocab_file, num):
+    """
+    Return raw data of aclImdb dataset.
+
+    Args:
+        dir_name (str): String of aclImdb dataset's path.
+        vocab_file (str): String of dictionary's path.
+        num (int): Number of samples to read.
+
+    Returns:
+        List
+    """
+    if not os.path.isdir(dir_name):
+        raise IOError("Directory {} does not exist".format(dir_name))
+    for root, dirs, files in os.walk(dir_name):
+        for index, file_name_extension in enumerate(files):
+            if index < num:
+                file_path = os.path.join(root, file_name_extension)
+                file_name, _ = file_name_extension.split('.', 1)
+                id_, rating = file_name.split('_', 1)
+                with open(file_path, 'r') as f:
+                    raw_content = f.read()
+
+                dictionary = load_vocab(vocab_file)
+                vectors = [dictionary.get('[CLS]')]
+                vectors += [dictionary.get(i) if i in dictionary
+                            else dictionary.get('[UNK]')
+                            for i in re.findall(r"[\w']+|[{}]"
+                                                .format(string.punctuation),
+                                                raw_content)]
+                vectors += [dictionary.get('[SEP]')]
+                input_, mask, segment = inputs(vectors)
+                input_ids = np.reshape(np.array(input_), [-1])
+                input_mask = np.reshape(np.array(mask), [1, -1])
+                segment_ids = np.reshape(np.array(segment), [2, -1])
+                data = {
+                    "label": 1,
+                    "id": id_,
+                    "rating": float(rating),
+                    "input_ids": input_ids,
+                    "input_mask": input_mask,
+                    "segment_ids": segment_ids
+                }
+                yield data
+
+
+def convert_to_uni(text):
+    if isinstance(text, str):
+        return text
+    if isinstance(text, bytes):
+        return text.decode('utf-8', 'ignore')
+    raise Exception("The type %s cannot be converted!" % type(text))
+
+
+def load_vocab(vocab_file):
+    """load vocabulary to translate statement."""
+    vocab = collections.OrderedDict()
+    vocab.setdefault('blank', 2)
+    index = 0
+    with open(vocab_file) as reader:
+        while True:
+            tmp = reader.readline()
+            if not tmp:
+                break
+            token = convert_to_uni(tmp)
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def inputs(vectors, maxlen=50):
+    length = len(vectors)
+    if length > maxlen:
+        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
+    input_ = vectors + [0] * (maxlen - length)
+    mask = [1] * length + [0] * (maxlen - length)
+    segment = [0] * maxlen
+    return input_, mask, segment
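As a usage reference for the feature these tests exercise: the sketch below shows how the padded_sample/num_padded keywords of ds.MindDataset combine with num_shards/shard_id so that every shard yields the same number of rows. It is a minimal sketch and not part of the diff; the file name, the assumption of 10 real rows, and the -1 sentinel label are illustrative only, and the arguments are passed the same way the tests above pass them.

    import numpy as np
    import mindspore.dataset as ds

    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    # Dummy record used to pad the shards; label == -1 marks it as "not a real sample".
    padded_sample = {"id": -1, "file_name": "dummy.jpg", "label": -1, "data": b"\x00"}

    # Assuming the file holds 10 real rows (as in the tests above), num_padded=2 gives
    # 12 rows in total, i.e. 4 shards of exactly 3 rows each.
    data_set = ds.MindDataset("imagenet.mindrecord0", columns_list, num_readers,
                              num_shards=4, shard_id=0,
                              padded_sample=padded_sample, num_padded=2)

    for item in data_set.create_dict_iterator():
        if item["label"] == -1:
            continue  # skip padded rows, e.g. during evaluation
        # ... consume the real sample ...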