From bb51bb88d780b75dd03de940c0aef3766499efc4 Mon Sep 17 00:00:00 2001 From: jonwe Date: Wed, 13 May 2020 15:59:01 -0400 Subject: [PATCH] add compress in mindrecord --- .../engine/datasetops/source/mindrecord_op.cc | 371 ++------------ .../engine/datasetops/source/mindrecord_op.h | 55 +- .../ccsrc/mindrecord/common/shard_pybind.cc | 4 +- .../mindrecord/include/common/shard_utils.h | 3 + .../ccsrc/mindrecord/include/shard_column.h | 163 ++++++ .../ccsrc/mindrecord/include/shard_header.h | 3 - .../ccsrc/mindrecord/include/shard_reader.h | 15 +- .../ccsrc/mindrecord/include/shard_writer.h | 4 +- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 109 ++-- mindspore/ccsrc/mindrecord/io/shard_writer.cc | 10 + .../ccsrc/mindrecord/meta/shard_column.cc | 473 ++++++++++++++++++ .../ccsrc/mindrecord/meta/shard_header.cc | 6 +- mindspore/mindrecord/shardutils.py | 35 +- .../testOldVersion/aclImdb.mindrecord0 | Bin 0 -> 49216 bytes .../testOldVersion/aclImdb.mindrecord0.db | Bin 0 -> 16384 bytes .../testOldVersion/aclImdb.mindrecord1 | Bin 0 -> 49216 bytes .../testOldVersion/aclImdb.mindrecord1.db | Bin 0 -> 16384 bytes .../testOldVersion/aclImdb.mindrecord2 | Bin 0 -> 49216 bytes .../testOldVersion/aclImdb.mindrecord2.db | Bin 0 -> 16384 bytes .../testOldVersion/aclImdb.mindrecord3 | Bin 0 -> 49216 bytes .../testOldVersion/aclImdb.mindrecord3.db | Bin 0 -> 16384 bytes tests/ut/python/dataset/test_minddataset.py | 384 ++++++++++---- .../mindrecord/test_cifar100_to_mindrecord.py | 51 +- .../mindrecord/test_cifar10_to_mindrecord.py | 74 +-- .../mindrecord/test_imagenet_to_mindrecord.py | 40 +- .../mindrecord/test_mindrecord_exception.py | 120 ++--- .../ut/python/mindrecord/test_mnist_to_mr.py | 52 +- 27 files changed, 1227 insertions(+), 745 deletions(-) create mode 100644 mindspore/ccsrc/mindrecord/include/shard_column.h create mode 100644 mindspore/ccsrc/mindrecord/meta/shard_column.cc create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0 create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0.db create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1 create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2 create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2.db create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3 create mode 100644 tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 991869ac081..e7ed0e12a3f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -112,25 +112,26 @@ Status MindRecordOp::Init() { data_schema_ = std::make_unique(); - std::vector> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); - // check whether schema exists, if so use the first one - CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); - mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; + std::vector col_names = shard_reader_->get_shard_column()->GetColumnName(); + CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); + std::vector col_data_types = shard_reader_->get_shard_column()->GeColumnDataType(); + std::vector> col_shapes = shard_reader_->get_shard_column()->GetColumnShape(); 
bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything std::map colname_to_ind; - for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) { - std::string colname = it.key(); // key of the json, column name - mindrecord::json it_value = it.value(); // value, which contains type info and may contain shape + for (uint32_t i = 0; i < col_names.size(); i++) { + std::string colname = col_names[i]; ColDescriptor col_desc; + TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown - std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"]; + std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} - if (it_value["type"] == "bytes") { // rank = 1 + + if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1 col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); - } else if (it_value.find("shape") != it_value.end()) { - std::vector vec(it_value["shape"].size()); // temporary vector to hold shape - (void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); + } else if (col_shapes[i].size() > 0) { + std::vector vec(col_shapes[i].size()); // temporary vector to hold shape + (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); t_shape = TensorShape(vec); col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); } else { // unknown shape @@ -162,33 +163,10 @@ Status MindRecordOp::Init() { num_rows_ = shard_reader_->GetNumRows(); // Compute how many buffers we would need to accomplish rowsPerBuffer buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; - RETURN_IF_NOT_OK(SetColumnsBlob()); return Status::OK(); } -Status MindRecordOp::SetColumnsBlob() { - columns_blob_ = shard_reader_->GetBlobFields().second; - - // get the exactly blob fields by columns_to_load_ - std::vector columns_blob_exact; - for (auto &blob_field : columns_blob_) { - for (auto &column : columns_to_load_) { - if (column.compare(blob_field) == 0) { - columns_blob_exact.push_back(blob_field); - break; - } - } - } - - columns_blob_index_ = std::vector(columns_to_load_.size(), -1); - int32_t iBlob = 0; - for (auto &blob_exact : columns_blob_exact) { - columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++; - } - return Status::OK(); -} - // Destructor MindRecordOp::~MindRecordOp() {} @@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { } } -template -Status MindRecordOp::LoadFeature(std::shared_ptr *tensor, int32_t i_col, - const std::vector &columns_blob, const mindrecord::json &columns_json) const { - TensorShape new_shape = TensorShape::CreateUnknownRankShape(); - const unsigned char *data = nullptr; - - std::unique_ptr array_data; - std::string string_data; - - const ColDescriptor &cur_column = data_schema_->column(i_col); - std::string column_name = columns_to_load_[i_col]; - DataType type = cur_column.type(); - - // load blob column - if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) { - int32_t pos = columns_blob_.size() == 1 ? 
-1 : columns_blob_index_[i_col]; - RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column)); - } else { - switch (type.value()) { - case DataType::DE_UINT8: { - // For strings (Assume DE_UINT8 is reserved for strings) - RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json)); - data = reinterpret_cast(common::SafeCStr(string_data)); - break; - } - case DataType::DE_FLOAT32: { - // For both float scalars and arrays - RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false)); - data = reinterpret_cast(array_data.get()); - break; - } - case DataType::DE_FLOAT64: { - // For both double scalars and arrays - RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true)); - data = reinterpret_cast(array_data.get()); - break; - } - default: { - // For both integers scalars and arrays - RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column)); - data = reinterpret_cast(array_data.get()); - break; - } - } - } - // Create Tensor with given details - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data)); - - return Status::OK(); -} - -Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data, - const std::vector &columns_blob, const int32_t pos, - const ColDescriptor &column) { - const auto kColumnSize = column.type().SizeInBytes(); - if (kColumnSize == 0) { - RETURN_STATUS_UNEXPECTED("column size is null"); - } - if (pos == -1) { - if (column.hasShape()) { - *new_shape = TensorShape::CreateUnknownRankShape(); - RETURN_IF_NOT_OK( - column.MaterializeTensorShape(static_cast(columns_blob.size() / kColumnSize), new_shape)); - } else { - std::vector shapeDetails = {static_cast(columns_blob.size() / kColumnSize)}; - *new_shape = TensorShape(shapeDetails); - } - *data = reinterpret_cast(&(columns_blob[0])); - return Status::OK(); - } - auto uint64_from_bytes = [&](int64_t pos) { - uint64_t result = 0; - for (uint64_t n = 0; n < kInt64Len; n++) { - result = (result << 8) + columns_blob[pos + n]; - } - return result; - }; - uint64_t iStart = 0; - for (int32_t i = 0; i < pos; i++) { - uint64_t num_bytes = uint64_from_bytes(iStart); - iStart += kInt64Len + num_bytes; - } - uint64_t num_bytes = uint64_from_bytes(iStart); - iStart += kInt64Len; - if (column.hasShape()) { - *new_shape = TensorShape::CreateUnknownRankShape(); - RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_bytes / kColumnSize), new_shape)); - } else { - std::vector shapeDetails = {static_cast(num_bytes / kColumnSize)}; - *new_shape = TensorShape(shapeDetails); - } - *data = reinterpret_cast(&(columns_blob[iStart])); - return Status::OK(); -} - -template -Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr *array_data, const std::string &column_name, - const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) { - if (!columns_json[column_name].is_array()) { - T value = 0; - RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double)); - - *new_shape = TensorShape::CreateScalar(); - *array_data = std::make_unique(1); - (*array_data)[0] = value; - } else { - if (column.hasShape()) { - *new_shape = TensorShape(column.shape()); - } else { - std::vector shapeDetails = {static_cast(columns_json[column_name].size())}; - *new_shape = TensorShape(shapeDetails); - } - - int idx = 0; - *array_data = std::make_unique(new_shape->NumOfElements()); - for (auto &element : 
columns_json[column_name]) { - T value = 0; - RETURN_IF_NOT_OK(GetFloat(&value, element, use_double)); - - (*array_data)[idx++] = value; - } - } - - return Status::OK(); -} - -template -Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) { - if (data.is_number()) { - *value = data; - } else if (data.is_string()) { - try { - if (use_double) { - *value = data.get(); - } else { - *value = data.get(); - } - } catch (mindrecord::json::exception &e) { - RETURN_STATUS_UNEXPECTED("Conversion to float failed."); - } - } else { - RETURN_STATUS_UNEXPECTED("Conversion to float failed."); - } - - return Status::OK(); -} - -template -Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr *array_data, const std::string &column_name, - const mindrecord::json &columns_json, const ColDescriptor &column) { - if (!columns_json[column_name].is_array()) { - T value = 0; - RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name])); - - *new_shape = TensorShape::CreateScalar(); - *array_data = std::make_unique(1); - (*array_data)[0] = value; - } else { - if (column.hasShape()) { - *new_shape = TensorShape(column.shape()); - } else { - std::vector shapeDetails = {static_cast(columns_json[column_name].size())}; - *new_shape = TensorShape(shapeDetails); - } - - int idx = 0; - *array_data = std::make_unique(new_shape->NumOfElements()); - for (auto &element : columns_json[column_name]) { - T value = 0; - RETURN_IF_NOT_OK(GetInt(&value, element)); - - (*array_data)[idx++] = value; - } - } - - return Status::OK(); -} - -template -Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) { - int64_t temp_value = 0; - bool less_than_zero = false; - - if (data.is_number_integer()) { - const mindrecord::json json_zero = 0; - if (data < json_zero) less_than_zero = true; - temp_value = data; - } else if (data.is_string()) { - std::string string_value = data; - - if (!string_value.empty() && string_value[0] == '-') { - try { - temp_value = std::stoll(string_value); - less_than_zero = true; - } catch (std::invalid_argument &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); - } catch (std::out_of_range &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); - } - } else { - try { - temp_value = static_cast(std::stoull(string_value)); - } catch (std::invalid_argument &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); - } catch (std::out_of_range &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); - } - } - } else { - RETURN_STATUS_UNEXPECTED("Conversion to int failed."); - } - - if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || - (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed. 
Out of range"); - } - *value = static_cast(temp_value); - - return Status::OK(); -} - -Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name, - const mindrecord::json &columns_json) { - *string_data = columns_json[column_name]; - std::vector shape_details = {static_cast(string_data->size())}; - *new_shape = TensorShape(shape_details); - - return Status::OK(); -} - Status MindRecordOp::WorkerEntry(int32_t worker_id) { TaskManager::FindMe()->Post(); std::unique_ptr io_block; RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->eoe()) { RETURN_IF_NOT_OK( out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); continue; } - if (io_block->eof() == true) { + if (io_block->eof()) { RETURN_IF_NOT_OK( out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); @@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_bu if (tupled_buffer.empty()) break; } for (const auto &tupled_row : tupled_buffer) { - std::vector columnsBlob = std::get<0>(tupled_row); + std::vector columns_blob = std::get<0>(tupled_row); mindrecord::json columns_json = std::get<1>(tupled_row); TensorRow tensor_row; - for (uint32_t j = 0; j < columns_to_load_.size(); ++j) { - std::shared_ptr tensor; - - const ColDescriptor &cur_column = data_schema_->column(j); - DataType type = cur_column.type(); - RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json)); - - tensor_row.push_back(std::move(tensor)); - } - + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json)); tensor_table->push_back(std::move(tensor_row)); } } @@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_bu return Status::OK(); } -Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr *tensor, int32_t i_col, - const std::vector &columns_blob, - const mindrecord::json &columns_json) const { - switch (type.value()) { - case DataType::DE_BOOL: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); +Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json) { + for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { + auto column_name = columns_to_load_[i_col]; + + // Initialize column parameters + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; + uint64_t column_data_type_size = 1; + std::vector column_shape; + + // Get column data + + auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName( + column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size, + &column_shape); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); } - case DataType::DE_INT8: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_UINT8: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_INT16: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_UINT16: { - return 
LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_INT32: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_UINT32: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_INT64: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_UINT64: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_FLOAT32: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - case DataType::DE_FLOAT64: { - return LoadFeature(tensor, i_col, columns_blob, columns_json); - } - default: { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "mindrecord column list type does not match any known types"); + + std::shared_ptr tensor; + const ColDescriptor &column = data_schema_->column(i_col); + DataType type = column.type(); + + // Set shape + auto num_elements = n_bytes / column_data_type_size; + if (column.hasShape()) { + auto new_shape = TensorShape(column.shape()); + RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_elements), &new_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + } else { + std::vector shapeDetails = {static_cast(num_elements)}; + auto new_shape = TensorShape(shapeDetails); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); } + tensor_row->push_back(std::move(tensor)); } + return Status::OK(); } Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index c8e333d3736..251b4f91302 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,6 +32,7 @@ #include "dataset/engine/datasetops/source/io_block.h" #include "dataset/util/queue.h" #include "dataset/util/status.h" +#include "mindrecord/include/shard_column.h" #include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_reader.h" #include "mindrecord/include/common/shard_utils.h" @@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp { Status Init(); - Status SetColumnsBlob(); - // Base-class override for NodePass visitor acceptor. // @param p - Pointer to the NodePass to be accepted. // @param modified - Whether this node visit modified the pipeline. 
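Note on the blob layout this patch introduces: ShardColumn::CompressInt (added in shard_column.cc further down) rewrites an int32/int64 blob column as a 4-byte big-endian element count, then a bitmap carrying a 2-bit width code per element (4 codes per byte), then each element narrowed to the smallest of int8/int16/int32/int64 and stored little-endian; UncompressInt reverses this on read. The standalone sketch below only illustrates that layout under those assumptions — CompressInt64Sketch is a hypothetical helper written for this note, not code from the patch, and the 4-byte count / 4-codes-per-byte / 2-bit-code constants mirror kBytesOfColumnLen, kNumDataOfByte and kDataTypeBits from shard_column.h.

#include <cstdint>
#include <vector>

// Sketch of the compressed layout for one int64 blob column:
// [4-byte big-endian count][2-bit type codes, 4 per byte][narrowed little-endian values].
std::vector<uint8_t> CompressInt64Sketch(const std::vector<int64_t> &values) {
  const uint64_t n = values.size();
  std::vector<uint8_t> out;
  for (int shift = 24; shift >= 0; shift -= 8) {  // element count, big-endian
    out.push_back(static_cast<uint8_t>((n >> shift) & 0xFF));
  }
  std::vector<uint8_t> bitmap((n + 3) / 4, 0);    // one 2-bit width code per element
  std::vector<uint8_t> payload;
  for (uint64_t i = 0; i < n; ++i) {
    const int64_t v = values[i];
    int code = 3;                                 // 0: int8, 1: int16, 2: int32, 3: int64
    if (v >= INT8_MIN && v <= INT8_MAX) code = 0;
    else if (v >= INT16_MIN && v <= INT16_MAX) code = 1;
    else if (v >= INT32_MIN && v <= INT32_MAX) code = 2;
    const int width = 1 << code;                  // bytes kept for this element
    for (int b = 0; b < width; ++b) {             // little-endian, two's complement
      payload.push_back(static_cast<uint8_t>((static_cast<uint64_t>(v) >> (8 * b)) & 0xFF));
    }
    bitmap[i / 4] |= static_cast<uint8_t>(code << (2 * (3 - i % 4)));  // earlier elements in high bits
  }
  out.insert(out.end(), bitmap.begin(), bitmap.end());
  out.insert(out.end(), payload.begin(), payload.end());
  return out;  // e.g. {1, 2, 300} shrinks from 24 raw bytes to 4 + 1 + 4 = 9 bytes
}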
@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp { Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); // Parses a single cell and puts the data into a tensor - // @param tensor - the tensor to put the parsed data in - // @param i_col - the id of column to parse + // @param tensor_row - the tensor row to put the parsed data in // @param columns_blob - the blob data received from the reader // @param columns_json - the data for fields received from the reader - template - Status LoadFeature(std::shared_ptr *tensor, int32_t i_col, const std::vector &columns_blob, - const mindrecord::json &columns_json) const; - - Status SwitchLoadFeature(const DataType &type, std::shared_ptr *tensor, int32_t i_col, - const std::vector &columns_blob, const mindrecord::json &columns_json) const; - - static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector &columns_blob, - const int32_t pos, const ColDescriptor &column); - - // Get shape and data (scalar or array) for tensor to be created (for floats and doubles) - // @param new_shape - the shape of tensor to be created. - // @param array_data - the array where data should be put in - // @param column_name - name of current column to be processed - // @param columns_json - the data for fields received from the reader - // @param column - description of current column from schema - // @param use_double - boolean to choose between float32 and float64 - template - static Status LoadFloat(TensorShape *new_shape, std::unique_ptr *array_data, const std::string &column_name, - const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double); - - // Get shape and data (scalar or array) for tensor to be created (for integers) - // @param new_shape - the shape of tensor to be created. 
- // @param array_data - the array where data should be put in - // @param column_name - name of current column to be processed - // @param columns_json - the data for fields received from the reader - // @param column - description of current column from schema - template - static Status LoadInt(TensorShape *new_shape, std::unique_ptr *array_data, const std::string &column_name, - const mindrecord::json &columns_json, const ColDescriptor &column); - - static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name, - const mindrecord::json &columns_json); - - // Get a single float value from the given json - // @param value - the float to put the value in - // @param arrayData - the given json containing the float - // @param use_double - boolean to choose between float32 and float64 - template - static Status GetFloat(T *value, const mindrecord::json &data, bool use_double); - - // Get a single integer value from the given json - // @param value - the integer to put the value in - // @param arrayData - the given json containing the integer - template - static Status GetInt(T *value, const mindrecord::json &data); + Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json); Status FetchBlockBuffer(const int32_t &buffer_id); diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 0391ee5e199..ee923ebc977 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) { .def("launch", &ShardReader::Launch) .def("get_header", &ShardReader::GetShardHeader) .def("get_blob_fields", &ShardReader::GetBlobFields) - .def("get_next", - (std::vector, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) + .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & + ShardReader::GetNextPy) .def("finish", &ShardReader::Finish) .def("close", &ShardReader::Close); } diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h index 3af4d7f8913..65a8d53e72c 100644 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h @@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4; enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; +const char kVersion[] = "3.0"; +const std::vector kSupportedVersion = {"2.0", kVersion}; + enum ShardType { kNLP = 0, kCV = 1, diff --git a/mindspore/ccsrc/mindrecord/include/shard_column.h b/mindspore/ccsrc/mindrecord/include/shard_column.h new file mode 100644 index 00000000000..e327ef511a9 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/include/shard_column.h @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ +#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ + +#include +#include +#include +#include +#include +#include "mindrecord/include/shard_header.h" + +namespace mindspore { +namespace mindrecord { +const uint64_t kUnsignedOne = 1; +const uint64_t kBitsOfByte = 8; +const uint64_t kDataTypeBits = 2; +const uint64_t kNumDataOfByte = 4; +const uint64_t kBytesOfColumnLen = 4; +const uint64_t kDataTypeBitMask = 3; +const uint64_t kDataTypes = 6; + +enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type }; + +enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound }; + +enum ColumnDataType { + ColumnBytes = 0, + ColumnString = 1, + ColumnInt32 = 2, + ColumnInt64 = 3, + ColumnFloat32 = 4, + ColumnFloat64 = 5, + ColumnNoDataType = 6 +}; + +// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"}; +const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8}; + +const std::vector ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32", + "int64", "float32", "float64"}; + +const std::unordered_map ColumnDataTypeMap = { + {"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32}, + {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}}; + +class ShardColumn { + public: + explicit ShardColumn(const std::shared_ptr &shard_header, bool compress_integer = true); + + ~ShardColumn() = default; + + /// \brief get column value by column name + MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape); + + /// \brief compress blob + std::vector CompressBlob(const std::vector &blob); + + /// \brief check if blob compressed + bool CheckCompressBlob() const { return has_compress_blob_; } + + uint64_t GetNumBlobColumn() const { return num_blob_column_; } + + std::vector GetColumnName() { return column_name_; } + + std::vector GeColumnDataType() { return column_data_type_; } + + std::vector> GetColumnShape() { return column_shape_; } + + /// \brief get column value from blob + MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *n_bytes); + + private: + /// \brief get column value from json + MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes); + + /// \brief get float value from json + template + MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); + + /// \brief get integer value from json + template + MSRStatus GetInt(std::unique_ptr *data_ptr, const json &json_column_value); + + /// \brief get column offset address and size from blob + MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx); + + /// \brief check if column name is available + ColumnCategory CheckColumnName(const std::string &column_name); + + /// \brief compress integer column + static vector CompressInt(const vector &src_bytes, const IntegerType &int_type); + + /// \brief uncompress integer array column + template + static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr *data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, uint64_t 
shift_idx); + + /// \brief convert big-endian bytes to unsigned int + /// \param bytes_array bytes array + /// \param pos shift address in bytes array + /// \param i_type integer type + /// \return unsigned int + static uint64_t BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &i_type); + + /// \brief convert unsigned int to big-endian bytes + /// \param value integer value + /// \param i_type integer type + /// \return bytes + static std::vector UIntToBytesBig(uint64_t value, const IntegerType &i_type); + + /// \brief convert unsigned int to little-endian bytes + /// \param value integer value + /// \param i_type integer type + /// \return bytes + static std::vector UIntToBytesLittle(uint64_t value, const IntegerType &i_type); + + /// \brief convert unsigned int to little-endian bytes + /// \param bytes_array bytes array + /// \param pos shift address in bytes array + /// \param src_i_type source integer typ0e + /// \param dst_i_type (output), destination integer type + /// \return integer + static int64_t BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr); + + private: + std::vector column_name_; // column name list + std::vector column_data_type_; // column data type list + std::vector> column_shape_; // column shape list + std::unordered_map column_name_id_; // column name id map + std::vector blob_column_; // blob column list + std::unordered_map blob_column_id_; // blob column name id map + bool has_compress_blob_; // if has compress blob + uint64_t num_blob_column_; // number of blob columns +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index 0f2473e910d..e4361c466a8 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -118,8 +118,6 @@ class ShardHeader { void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - const string GetVersion() { return version_; } - std::vector SerializeHeader(); MSRStatus PagesToFile(const std::string dump_file_name); @@ -175,7 +173,6 @@ class ShardHeader { uint32_t shard_count_; uint64_t header_size_; uint64_t page_size_; - string version_ = "2.0"; std::shared_ptr index_; std::vector shard_addresses_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 13d68b01f71..d1a427af276 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -43,6 +43,7 @@ #include #include "mindrecord/include/common/shard_utils.h" #include "mindrecord/include/shard_category.h" +#include "mindrecord/include/shard_column.h" #include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_index_generator.h" #include "mindrecord/include/shard_operator.h" @@ -111,6 +112,10 @@ class ShardReader { /// \return the metadata std::shared_ptr GetShardHeader() const; + /// \brief aim to get columns context + /// \return the columns + std::shared_ptr get_shard_column() const; + /// \brief get the number of shards /// \return # of shards int GetShardCount() const; @@ -185,7 +190,7 @@ class ShardReader { /// \brief return a batch, given that one is ready, python API /// \return a batch of images and image data - std::vector, pybind11::object>> GetNextPy(); + 
std::vector>, pybind11::object>> GetNextPy(); /// \brief get blob filed list /// \return blob field list @@ -295,16 +300,18 @@ class ShardReader { /// \brief get number of classes int64_t GetNumClasses(const std::string &category_field); + /// \brief get meta of header std::pair> GetMeta(const std::string &file_path, json &meta_data); - /// \brief get exactly blob fields data by indices - std::vector ExtractBlobFieldBySelectColumns(std::vector &blob_fields_bytes, - std::vector &ordered_selected_columns_index); + + /// \brief extract uncompressed data based on column list + std::pair>> UnCompressBlob(const std::vector &raw_blob_data); protected: uint64_t header_size_; // header size uint64_t page_size_; // page size int shard_count_; // number of shards std::shared_ptr shard_header_; // shard header + std::shared_ptr shard_column_; // shard column std::vector database_paths_; // sqlite handle list std::vector file_paths_; // file paths diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 4679814287e..6175180c927 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -36,6 +36,7 @@ #include #include #include "mindrecord/include/common/shard_utils.h" +#include "mindrecord/include/shard_column.h" #include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_header.h" #include "mindrecord/include/shard_index.h" @@ -242,7 +243,8 @@ class ShardWriter { std::vector file_paths_; // file paths std::vector> file_streams_; // file handles - std::shared_ptr shard_header_; // shard headers + std::shared_ptr shard_header_; // shard header + std::shared_ptr shard_column_; // shard columns std::map> err_mg_; // used for storing error raw_data info diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 1f0a6b8dce2..7b3e222c9e8 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector &file_paths, bool loa shard_header_ = std::make_shared(sh); header_size_ = shard_header_->GetHeaderSize(); page_size_ = shard_header_->GetPageSize(); + // version < 3.0 + if (first_meta_data["version"] < kVersion) { + shard_column_ = std::make_shared(shard_header_, false); + } else { + shard_column_ = std::make_shared(shard_header_, true); + } num_rows_ = 0; auto row_group_summary = ReadRowGroupSummary(); for (const auto &rg : row_group_summary) { @@ -226,6 +232,8 @@ void ShardReader::Close() { std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } +std::shared_ptr ShardReader::get_shard_column() const { return shard_column_; } + int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } int ShardReader::GetNumRows() const { return num_rows_; } @@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector ShardReader::ExtractBlobFieldBySelectColumns( - std::vector &blob_fields_bytes, std::vector &ordered_selected_columns_index) { - std::vector exactly_blob_fields_bytes; - - auto uint64_from_bytes = [&](int64_t pos) { - uint64_t result = 0; - for (uint64_t n = 0; n < kInt64Len; n++) { - result = (result << 8) + blob_fields_bytes[pos + n]; - } - return result; - }; - - // get the exactly blob fields - uint32_t current_index = 0; - uint64_t current_offset = 0; - uint64_t data_len = uint64_from_bytes(current_offset); - while (current_offset < 
blob_fields_bytes.size()) { - if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(), - [¤t_index](uint32_t &index) { return index == current_index; })) { - exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset, - blob_fields_bytes.begin() + current_offset + kInt64Len + data_len); - } - current_index++; - current_offset += kInt64Len + data_len; - data_len = uint64_from_bytes(current_offset); - } - - return exactly_blob_fields_bytes; -} - TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { // All tasks are done if (task_id >= static_cast(tasks_.Size())) { @@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ return std::make_pair(FAILED, std::vector, json>>()); } - // extract the exactly blob bytes by selected columns - std::vector images_with_exact_columns; - if (selected_columns_.size() == 0) { - images_with_exact_columns = images; - } else { - auto blob_fields = GetBlobFields(); - - std::vector ordered_selected_columns_index; - uint32_t index = 0; - for (auto &blob_field : blob_fields.second) { - for (auto &field : selected_columns_) { - if (field.compare(blob_field) == 0) { - ordered_selected_columns_index.push_back(index); - break; - } - } - index++; - } - - if (ordered_selected_columns_index.size() != 0) { - // extract the images - if (blob_fields.second.size() == 1) { - if (ordered_selected_columns_index.size() == 1) { - images_with_exact_columns = images; - } - } else { - images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index); - } - } - } - // Deliver batch data to output map std::vector, json>> batch; - batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); + batch.emplace_back(std::move(images), std::move(std::get<2>(task))); + return std::make_pair(SUCCESS, std::move(batch)); } @@ -1369,16 +1317,41 @@ std::vector, json>> ShardReader::GetNextById(con return std::move(ret.second); } -std::vector, pybind11::object>> ShardReader::GetNextPy() { +std::pair>> ShardReader::UnCompressBlob( + const std::vector &raw_blob_data) { + auto loaded_columns = selected_columns_.size() == 0 ? 
shard_column_->GetColumnName() : selected_columns_; + auto blob_fields = GetBlobFields().second; + std::vector> blob_data; + for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { + if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; + return {FAILED, std::vector>(blob_fields.size(), std::vector())}; + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + std::vector column(data, data + (n_bytes / sizeof(unsigned char))); + blob_data.push_back(column); + } + return {SUCCESS, blob_data}; +} + +std::vector>, pybind11::object>> ShardReader::GetNextPy() { auto res = GetNext(); - vector, pybind11::object>> jsonData; - std::transform(res.begin(), res.end(), std::back_inserter(jsonData), - [](const std::tuple, json> &item) { + vector>, pybind11::object>> data; + std::transform(res.begin(), res.end(), std::back_inserter(data), + [this](const std::tuple, json> &item) { auto &j = std::get<1>(item); pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); + auto ret = UnCompressBlob(std::get<0>(item)); + return std::make_tuple(ret.second, std::move(obj)); }); - return jsonData; + return data; } void ShardReader::Reset() { diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 43967c43c54..0b0acf52d7c 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { MS_LOG(ERROR) << "Open file failed"; return FAILED; } + shard_column_ = std::make_shared(shard_header_); return SUCCESS; } @@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) shard_header_ = header_data; shard_header_->SetHeaderSize(header_size_); shard_header_->SetPageSize(page_size_); + shard_column_ = std::make_shared(shard_header_); return SUCCESS; } @@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } + + // compress blob + if (shard_column_->CheckCompressBlob()) { + for (auto &blob : blob_data) { + blob = shard_column_->CompressBlob(blob); + } + } + // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/mindrecord/meta/shard_column.cc new file mode 100644 index 00000000000..86ad0c96d7b --- /dev/null +++ b/mindspore/ccsrc/mindrecord/meta/shard_column.cc @@ -0,0 +1,473 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindrecord/include/shard_column.h" + +#include "common/utils.h" +#include "mindrecord/include/common/shard_utils.h" +#include "mindrecord/include/shard_error.h" + +namespace mindspore { +namespace mindrecord { +ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool compress_integer) { + auto first_schema = shard_header->GetSchemas()[0]; + auto schema = first_schema->GetSchema()["schema"]; + + bool has_integer_array = false; + for (json::iterator it = schema.begin(); it != schema.end(); ++it) { + const std::string &column_name = it.key(); + column_name_.push_back(column_name); + + json it_value = it.value(); + + std::string str_type = it_value["type"]; + column_data_type_.push_back(ColumnDataTypeMap.at(str_type)); + if (it_value.find("shape") != it_value.end()) { + std::vector vec(it_value["shape"].size()); + std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); + column_shape_.push_back(vec); + if (str_type == "int32" || str_type == "int64") { + has_integer_array = true; + } + } else { + std::vector vec = {}; + column_shape_.push_back(vec); + } + } + + for (uint64_t i = 0; i < column_name_.size(); i++) { + column_name_id_[column_name_[i]] = i; + } + + auto blob_fields = first_schema->GetBlobFields(); + + for (const auto &field : blob_fields) { + blob_column_.push_back(field); + } + + for (uint64_t i = 0; i < blob_column_.size(); i++) { + blob_column_id_[blob_column_[i]] = i; + } + + has_compress_blob_ = (compress_integer && has_integer_array); + num_blob_column_ = blob_column_.size(); +} + +MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape) { + // Skip if column not found + auto column_category = CheckColumnName(column_name); + if (column_category == ColumnNotFound) { + return FAILED; + } + + // Get data type and size + auto column_id = column_name_id_[column_name]; + *column_data_type = column_data_type_[column_id]; + *column_data_type_size = ColumnDataTypeSize[*column_data_type]; + *column_shape = column_shape_[column_id]; + + // Retrieve value from json + if (column_category == ColumnInRaw) { + if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { + MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; + return FAILED; + } + *data = reinterpret_cast(data_ptr->get()); + return SUCCESS; + } + + // Retrieve value from blob + if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { + MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; + return FAILED; + } + if (*data == nullptr) { + *data = reinterpret_cast(data_ptr->get()); + } + return SUCCESS; +} + +MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes) { + auto column_id = 
column_name_id_[column_name]; + auto column_data_type = column_data_type_[column_id]; + + // Initialize num bytes + *n_bytes = ColumnDataTypeSize[column_data_type]; + auto json_column_value = columns_json[column_name]; + switch (column_data_type) { + case ColumnFloat32: { + return GetFloat(data_ptr, json_column_value, false); + } + case ColumnFloat64: { + return GetFloat(data_ptr, json_column_value, true); + } + case ColumnInt32: { + return GetInt(data_ptr, json_column_value); + } + case ColumnInt64: { + return GetInt(data_ptr, json_column_value); + } + default: { + // Convert string to c_str + std::string tmp_string = json_column_value; + *n_bytes = tmp_string.size(); + auto data = reinterpret_cast(common::SafeCStr(tmp_string)); + *data_ptr = std::make_unique(*n_bytes); + for (uint32_t i = 0; i < *n_bytes; i++) { + (*data_ptr)[i] = *(data + i); + } + break; + } + } + return SUCCESS; +} + +template +MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, + bool use_double) { + std::unique_ptr array_data = std::make_unique(1); + if (!json_column_value.is_string() && !json_column_value.is_number()) { + MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; + return FAILED; + } + if (json_column_value.is_number()) { + array_data[0] = json_column_value; + } else { + // Convert string to float + try { + if (use_double) { + array_data[0] = json_column_value.get(); + } else { + array_data[0] = json_column_value.get(); + } + } catch (json::exception &e) { + MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; + return FAILED; + } + } + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(sizeof(T)); + for (uint32_t i = 0; i < sizeof(T); i++) { + (*data_ptr)[i] = *(data + i); + } + + return SUCCESS; +} + +template +MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { + std::unique_ptr array_data = std::make_unique(1); + int64_t temp_value; + bool less_than_zero = false; + + if (json_column_value.is_number_integer()) { + const json json_zero = 0; + if (json_column_value < json_zero) less_than_zero = true; + temp_value = json_column_value; + } else if (json_column_value.is_string()) { + std::string string_value = json_column_value; + + if (!string_value.empty() && string_value[0] == '-') { + try { + temp_value = std::stoll(string_value); + less_than_zero = true; + } catch (std::invalid_argument &e) { + MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; + return FAILED; + } catch (std::out_of_range &e) { + MS_LOG(ERROR) << "Conversion to int failed, out of range."; + return FAILED; + } + } else { + try { + temp_value = static_cast(std::stoull(string_value)); + } catch (std::invalid_argument &e) { + MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; + return FAILED; + } catch (std::out_of_range &e) { + MS_LOG(ERROR) << "Conversion to int failed, out of range."; + return FAILED; + } + } + } else { + MS_LOG(ERROR) << "Conversion to int failed."; + return FAILED; + } + + if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || + (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { + MS_LOG(ERROR) << "Conversion to int failed. 
Out of range"; + return FAILED; + } + array_data[0] = static_cast(temp_value); + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(sizeof(T)); + for (uint32_t i = 0; i < sizeof(T); i++) { + (*data_ptr)[i] = *(data + i); + } + + return SUCCESS; +} + +MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *n_bytes) { + uint64_t offset_address = 0; + auto column_id = column_name_id_[column_name]; + if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { + return FAILED; + } + + auto column_data_type = column_data_type_[column_id]; + if (has_compress_blob_ && column_data_type == ColumnInt32) { + if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { + return FAILED; + } + } else if (has_compress_blob_ && column_data_type == ColumnInt64) { + if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { + return FAILED; + } + } else { + *data = reinterpret_cast(&(columns_blob[offset_address])); + } + + return SUCCESS; +} + +ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { + auto it_column = column_name_id_.find(column_name); + if (it_column == column_name_id_.end()) { + return ColumnNotFound; + } + auto it_blob = blob_column_id_.find(column_name); + return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; +} + +std::vector ShardColumn::CompressBlob(const std::vector &blob) { + // Skip if no compress columns + if (!CheckCompressBlob()) return blob; + + std::vector dst_blob; + uint64_t i_src = 0; + for (int64_t i = 0; i < num_blob_column_; i++) { + // Get column data type + auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; + auto int_type = src_data_type == ColumnInt32 ? 
kInt32Type : kInt64Type; + + // Compress and return is blob has 1 column only + if (num_blob_column_ == 1) { + return CompressInt(blob, int_type); + } + + // Just copy and continue if column dat type is not int32/int64 + uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type); + if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) { + dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes); + i_src += kInt64Len + num_bytes; + continue; + } + + // Get column slice in source blob + std::vector blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes); + // Compress column + auto dst_blob_slice = CompressInt(blob_slice, int_type); + // Get new column size + auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type); + // Append new colmn size + dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end()); + // Append new colmn data + dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end()); + i_src += kInt64Len + num_bytes; + } + MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; + return dst_blob; +} + +vector ShardColumn::CompressInt(const vector &src_bytes, const IntegerType &int_type) { + uint64_t i_size = kUnsignedOne << int_type; + // Get number of elements + uint64_t src_n_int = src_bytes.size() / i_size; + // Calculate bitmap size (bytes) + uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte; + + // Initilize destination blob, more space than needed, will be resized + vector dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0); + + // Write number of elements to destination blob + vector size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type); + for (uint64_t n = 0; n < kBytesOfColumnLen; n++) { + dst_bytes[n] = size_by_bytes[n]; + } + + // Write compressed int + uint64_t i_dst = kBytesOfColumnLen + bitmap_size; + for (uint64_t i = 0; i < src_n_int; i++) { + // Initialize destination data type + IntegerType dst_int_type = kInt8Type; + // Shift to next int position + uint64_t pos = i * (kUnsignedOne << int_type); + // Narrow down this int + int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); + + // Write this int to destination blob + uint64_t u_n = *reinterpret_cast(&i_n); + auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type); + for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) { + dst_bytes[i_dst++] = temp_bytes[j]; + } + + // Update date type in bit map + dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |= + (dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte)))); + } + // Resize destination blob + dst_bytes.resize(i_dst); + MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << "."; + return dst_bytes; +} + +MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx) { + if (num_blob_column_ == 1) { + *num_bytes = columns_blob.size(); + *shift_idx = 0; + return SUCCESS; + } + auto blob_id = blob_column_id_[column_name_[column_id]]; + + for (int32_t i = 0; i < blob_id; i++) { + *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); + } + *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); + + (*shift_idx) += kInt64Len; + + return SUCCESS; +} + +template +MSRStatus ShardColumn::UncompressInt(const uint64_t 
&column_id, std::unique_ptr *data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, + uint64_t shift_idx) { + auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); + *num_bytes = sizeof(T) * num_elements; + + // Parse integer array + uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; + auto array_data = std::make_unique(num_elements); + + for (uint64_t i = 0; i < num_elements; i++) { + uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; + uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; + auto mr_int_type = static_cast(i_type); + int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); + i_source += (kUnsignedOne << i_type); + array_data[i] = static_cast(i64); + } + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(*num_bytes); + memcpy(data_ptr->get(), data, *num_bytes); + + return SUCCESS; +} + +uint64_t ShardColumn::BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &i_type) { + uint64_t result = 0; + for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) { + result = (result << kBitsOfByte) + bytes_array[pos + i]; + } + return result; +} + +std::vector ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { + uint64_t n_bytes = kUnsignedOne << i_type; + std::vector result(n_bytes, 0); + for (uint64_t i = 0; i < n_bytes; i++) { + result[n_bytes - 1 - i] = value & std::numeric_limits::max(); + value >>= kBitsOfByte; + } + return result; +} + +std::vector ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { + uint64_t n_bytes = kUnsignedOne << i_type; + std::vector result(n_bytes, 0); + for (uint64_t i = 0; i < n_bytes; i++) { + result[i] = value & std::numeric_limits::max(); + value >>= kBitsOfByte; + } + return result; +} + +int64_t ShardColumn::BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &src_i_type, IntegerType *dst_i_type) { + uint64_t u_temp = 0; + for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) { + u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i]; + } + + int64_t i_out; + switch (src_i_type) { + case kInt8Type: { + i_out = (int8_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt16Type: { + i_out = (int16_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt32Type: { + i_out = (int32_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt64Type: { + i_out = (int64_t)(u_temp & std::numeric_limits::max()); + break; + } + default: { + i_out = 0; + } + } + + if (!dst_i_type) { + return i_out; + } + + if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt8Type; + } else if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt16Type; + } else if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt32Type; + } else { + *dst_i_type = kInt64Type; + } + return i_out; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 3adb017352f..ec177394ef3 100644 --- 
a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &heade json header; header = ret.second; header["shard_addresses"] = realAddresses; - if (header["version"] != version_) { + if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() - << ", lib version is: " << version_; + << ", lib version is: " << kVersion; thread_status = true; return; } @@ -339,7 +339,7 @@ std::vector ShardHeader::SerializeHeader() { s += "\"shard_addresses\":" + address + ","; s += "\"shard_id\":" + std::to_string(shardId) + ","; s += "\"statistics\":" + stats + ","; - s += "\"version\":\"" + version_ + "\""; + s += "\"version\":\"" + std::string(kVersion) + "\""; s += "}"; header.emplace_back(s); } diff --git a/mindspore/mindrecord/shardutils.py b/mindspore/mindrecord/shardutils.py index a71dd228f64..31be5382e89 100644 --- a/mindspore/mindrecord/shardutils.py +++ b/mindspore/mindrecord/shardutils.py @@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema): if not blob_fields: return raw - # Get the order preserving sequence of columns in blob - ordered_columns = [] + loaded_columns = [] if columns: - for blob_field in blob_fields: - if blob_field in columns: - ordered_columns.append(blob_field) + for column in columns: + if column in blob_fields: + loaded_columns.append(column) else: - ordered_columns = blob_fields - - blob_bytes = bytes(blob) + loaded_columns = blob_fields def _render_raw(field, blob_data): data_type = schema[field]['type'] @@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): else: raw[field] = blob_data - if len(blob_fields) == 1: - if len(ordered_columns) == 1: - _render_raw(blob_fields[0], blob_bytes) - return raw - return raw - - def _int_from_bytes(xbytes: bytes) -> int: - return int.from_bytes(xbytes, 'big') - - def _blob_at_position(pos): - start = 0 - for _ in range(pos): - n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) - start += 8 + n_bytes - n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) - start += 8 - return blob_bytes[start : start + n_bytes] - - for i, blob_field in enumerate(ordered_columns): - _render_raw(blob_field, _blob_at_position(i)) + for i, blob_field in enumerate(loaded_columns): + _render_raw(blob_field, bytes(blob[i])) return raw diff --git a/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0 b/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0 new file mode 100644 index 0000000000000000000000000000000000000000..341d3c4eec805707e9522ffb7e12e22426127936 GIT binary patch literal 49216 zcmeI(&ui2`6u|MVy%arr^HAm-%l;^9S-foVQV;|Y1Q!XD&9n_27-w6ffU7>`dOgc{BYabL`oEt85E%>1^z>Fi*ZZGwAetC&z(Frm_1H zo+mDib2At%?gq@jq&9M?2~0k|bQ2pUu^F_N9|JSD7g9sg-82rf+2?Xb%4xjH@?*bQ zr+?xD#r0fPn$y>BPQzEjcS9*&F4B0tQVU*Wv#YrT{Hb)~z|79ib65B?m_+% zvto0S_A=N$nylR%t+REsetl|xu(!ka|MFbk&cB3S_l3e#x}|57&Z6dZ7@KAGI<$gR zjiSwHycxwtipDM<3+ptUhrU`kyZY0PNqe#YnII>ck(*HPGuO5)b1d}>tKP1S9j9lLI69ihN1M6v zEQ{NUa4*_136r>Ds8Q$ z_4WtyZ}5k)Eo_<7Y90Qn=5KF&|JYB|3oycn?`Zor>V4Ad^0)7rzj|~CAb=J zcj=dCKSB>ZmtIOK6bij`cC%^>A-$Cbo)_1d9UjNmAA51d-m~uIEK(Pf>9{{rJETM? 
zC3{K{LU?wVAuVQ}leKx#`Q(m~lkMG4Mfr>4_dNqluK%WK5^~`b7?bhC` z(Y11=Lf=N|c8=T9XS$aSY-OWtv8t5sorX?Jt(M~kYO!x!9oubfLZNS;E_1hKyn}~*3TmMmrUBwYQ}N*!$#6$UB&%D z6kna_w)PE+?;P3wxjNM6)hot&8n4x!uw2=$P;&Wv6uo}=A~xUZ!~Uz;Bs~u^{P1Zu zvCO3=lFO}V4ldsl`CWdNAIt;=0SG_<0uX=z1Rwwb2tWV=5P-n{Ah5v;G*77zj0tKr z_N;qj`I4ertg}uFA}`E~1rwq~w;q~L#N`(vf6Jfpo0*^> z009U<00Izz00bZa0SG_<0uZ>b0y$32y#k&N%zD7hEda*35c%Zye=5Hc^FToW0uX=z p1Rwwb2tWV=5P$##AaLIW9#gi{AI8pjG}s=SKLt#q;bb~$`~k+6smlNW literal 0 HcmV?d00001 diff --git a/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1 b/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1 new file mode 100644 index 0000000000000000000000000000000000000000..4dcbd9119225ce48228a40ed730aa4c1d2ef91e5 GIT binary patch literal 49216 zcmeI(&ui2`6u|Kne-%A@^HAm-%l;^9S-iA(5d=X5!6IR@Gi?Lu?vkWxU6$Sc7ybj@ z{oe$y^}S@WX}7%yikI&kb|!D$yqSKIIri1vPSxh-#9Qa{xJbWxGwKZod%KZIC(eI~ zk5Zqx!i)~*R|95bGMo6!M5Y*@_){CF&WyT?kCB<#W2qtOe&XVM`nj5savHC*!Z>W! z8}5Zbc|Mbs=HT_4gZSnCyL~BMEz)?iQVSmE)3ccb!l`uQ$V`up3SWjZ6el)!)smhZ zu(Wx(K1p{GT|b(vT^~)dt8sm_VRPzmu($i~|K+)YU3dwD{&R(?bt}&#n{n6;!v$EJ71R>&5fw(3pjk%lU3nwPABD{Zx-z)SMqw% zM&@w0&6F+YO_Aih^hG-*FRb0He|wSo>uO7cb1d`=tKY75F84(tg-3_Rj2xBz{hh8^)Ypk$1Q0*~ z0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY** z5I_I{1Q0*~0R#|0009ILK;VxBp5D2(6aRizM9maOkPT`qmy-- zrfd6FcXg_cb*Z){L6>V=g~Vm`BtPM~+C#@pVpypAX$yPa$dIX}&dd2%I@$U0)%iq3h;p)*LfB*sr zAb=&;g0|S literal 0 HcmV?d00001 diff --git a/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db b/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db new file mode 100644 index 0000000000000000000000000000000000000000..88f8e69afdbb2bae85a85e869993889309b467b7 GIT binary patch literal 16384 zcmeI(O={af7zW^((ZrT(+%XNs3vVWng$-_S$N>^Lp4cLal*$T8R>D6j1mv1J4s>6- z?^SwP+>rfvubaKI~u3Ch2~d;fLGV z#L$aoLi0|UF_@E~!2tWV=5P$##AOHafKmY;|fB*#kzrZ>#(gLNX$(Tuv z=AIGrclmYW#v+qV)-~eX6ILuNl4VIZS$9p)qFFHY(WXASOgH%>{f)TzLd0+JQ+(4C z6a*ju0SG_<0uX=z1Rwwb2tWV=4^<$?seW3((}5XlTt5R~oSSAL`Td`YuSCC45P$## rAOHafKmY;|fB*y_009U*bb+Uo?evGSJsu6V#`;eI(`Y!Ej<){*pv|er literal 0 HcmV?d00001 diff --git a/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2 b/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2 new file mode 100644 index 0000000000000000000000000000000000000000..dc34999f1799e6ae3acd3bd5959f2bb158802d62 GIT binary patch literal 49216 zcmeI(&ubGw6u|L_m)f4ac_?#UOMaACf|nLA1wjx&u*fo-oi+>GB#mjdNnb|jQ-b_EcbL{!!R@r7|X<;Q-r zPXELQircBIG-t2hoQ1DW-5Di{~9#<^SenT=cqT-fVI&uM=!# z_K%xP*>YAFiEay*H&gP$n$5bqi&S4%Q^KEPu3uR7c5NJIF3+X#;KJBpdVU>8N7wSv zW^Opi;(l4c7Xh8aMQiC$Nl3+_Gt29m5 z_O0rwR2}P5ZFPb!*S7SDi}Fc+{Bya7j_X9fQ1{d3_k55cQwyD!^DlL>_2cQmhhFce z#-!ET@5#TxAI7$@WlpPg_^X;<-T8jsPt*%A!ieu^^ET>z((CfK@0!1QbO<1T00Iag zfB*srAbKoe!AB4uK)g{91i>3mvb0^tn0LzH`=jYOIee4;`XM(-Y453fF^lE1$#mSG$sJNA zl#)Fu2_Zb+^pF+P@9EmS=zMa^$jSEZr;_+ZjM5Kse^-3e0}28VfB*y_009U<00Izz z00jPtz@bs7Y;MxwtUpNNUe^xVk!K&N<&5tK$_|wb?Ut+La=n(uWT?(U>G^s)aozfx zHM&-*R_WU~JDuaT)tTz$16w&!zF3no@J}PBEmzC&Lb*6^Lmt~5r9)9*pDuGZW&DGK zo(h+_+cS~!^ao$dr2TXP|K#{;SJyK0ep;^IS~PT1ChupF`%5NkXf@+_`%yFPu_2TG zAWklKR7VAxC2)@H;9MT6^V%h2Ba7GTk658%S1Gx8K8jzze39t8`l$bEHp%v*96x%T zPYivjnd#+L(i<1=iTEx)iw}B&f&c^{009U<00Izz00bZa0SG|g{|jvJBDE+rO~y=W zH1~{zzbkGSR~DIUvaXS=KVU^`kt|EP#k!WDMbk3%(WXASOt;K+{f)TzLd0+JQ+(4C z6a*ju0SG_<0uX=z1Rwwb2tWV=cU7RkseW3(vw<1+xPAt}I5$lz{r#VcuSCC45P$## rAOHafKmY;|fB*y_009Wxb%BSJ?evF>OCV4vg+fCSC+m_36x&Ej+te64|3&|S z?)~2iUDNlZlNCEz3N2l}H~3EP-o3kgq#JX1zf-lDIde8~c~qodof-B9;ofdw(sAOx zL?@}slEMs+XEy_8XfhkS%mk(wow;)xrHL7K=N|(zwWm@;(%m?T^2z6FM#^ct&hlfw 
zS#PlC1LfsZR+@vOHwV${{dfCPyjrC3W~CN9%_kRA3HVd##(|lfoD{C~XDH5Wo>WUh zIbdn?a&?mKJXk-PtX&>o}W6zFJ#$wkk868Bx<0aQvbttNh=bj?1ub=Itgo@_NBW z=3uwYlr3jXk@&K7MLQ)gtlg}Ccai$*YD@TY%=HVa->yxP+!cis9v&Mz$_~$yc;{R` z+T4vMdD2ycThZQ~(f*y$Q1Lduq%tkabR>I@FS0B!pIly~ldUM>Qo);Qf*CwF4wm7iHqtzfkwn=J&jnAyW&Tm-8=mvh(BF!w+Hj zQ)AMa?YHFL;I|W7+A62h+5T0}UtjtDfuCp=V1yCh(fV!F`=r<9Z{Ib4_2>{l009IL zKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~ z0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY** z5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5cn?wTYlH8@0pEV Q_Gshz@y7F$jpwJo0Vw6+wEzGB literal 0 HcmV?d00001 diff --git a/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db b/tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db new file mode 100644 index 0000000000000000000000000000000000000000..f04f73b7270d54764efb319133c9f062c1550c3b GIT binary patch literal 16384 zcmeI(O={af7zW^((bSRaI%67&7v4-D3me?v;sYddJh4R-DU}tHtb~752*@>c9O%Au z->dWzg&v{w0HvE23WY*99VxOMhmh`C0`G_A84a)G&kx#2#=Ymx`7DwzCev|$CYz*0 zC?&g65<+;k=^-tq-;=d@(fQ<#k(2G6Pet*I7{woCeMfxM0}28VfB*y_009U<00Izz z00jPtz=4q~Z*J1ztUrjOUf1&5p=%wg<&5w7$_kVWtd^tXa=n_wWS~w1>3VuQcAVOq zHM&-=ROs6%J)P~g)v4-b16w&!wpf+Y_fA5)EmzBS1GzYFT^?B-r9+`_oh);=WxW0U zo(h(^yECD3^#|X`B>i-J@A&9?S2r^Aep-&#S~PT9ChKRB`%5NmXfC%^ww@s;Qo3IY&- r00bZa0SG_<0uX=z1Rwx``!4W=vSxo6+vCw-d#wKyFpY+j>8SY!q^qgv literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index 57c19fbd80b..a882dc6bcb1 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -35,6 +35,7 @@ CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" CV_DIR_NAME = "../data/mindrecord/testImageNetData" NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" +OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord" NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt" @@ -46,7 +47,8 @@ def add_and_remove_cv_file(): for x in range(FILES_NUM)] for x in paths: os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None - os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None + os.remove("{}.db".format(x)) if os.path.exists( + "{}.db".format(x)) else None writer = FileWriter(CV_FILE_NAME, FILES_NUM) data = get_data(CV_DIR_NAME) cv_schema_json = {"id": {"type": "int32"}, @@ -96,13 +98,105 @@ def add_and_remove_nlp_file(): os.remove("{}.db".format(x)) +@pytest.fixture +def add_and_remove_nlp_compress_file(): + """add/remove nlp file""" + paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) + for x in range(FILES_NUM)] + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(NLP_FILE_NAME, FILES_NUM) + data = [] + for row_id in range(16): + data.append({ + "label": row_id, + "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, + 255, 256, -32768, 32767, -32769, 32768, -2147483648, + 2147483647], dtype=np.int32), [-1]), + "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, + 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), + "array_c": str.encode("nlp data"), + "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) + }) + nlp_schema_json = {"label": {"type": 
"int32"}, + "array_a": {"type": "int32", + "shape": [-1]}, + "array_b": {"type": "int64", + "shape": [1, -1]}, + "array_c": {"type": "bytes"}, + "array_d": {"type": "int64", + "shape": [2, -1]} + } + writer.set_header_size(1 << 14) + writer.set_page_size(1 << 15) + writer.add_schema(nlp_schema_json, "nlp_schema") + writer.write_raw_data(data) + writer.commit() + yield "yield_nlp_data" + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + + +def test_nlp_compress_data(add_and_remove_nlp_compress_file): + """tutorial for nlp minderdataset.""" + data = [] + for row_id in range(16): + data.append({ + "label": row_id, + "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, + 255, 256, -32768, 32767, -32769, 32768, -2147483648, + 2147483647], dtype=np.int32), [-1]), + "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, + 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), + "array_c": str.encode("nlp data"), + "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) + }) + num_readers = 1 + data_set = ds.MindDataset( + NLP_FILE_NAME + "0", None, num_readers, shuffle=False) + assert data_set.get_dataset_size() == 16 + num_iter = 0 + for x, item in zip(data, data_set.create_dict_iterator()): + assert (item["array_a"] == x["array_a"]).all() + assert (item["array_b"] == x["array_b"]).all() + assert item["array_c"].tobytes() == x["array_c"] + assert (item["array_d"] == x["array_d"]).all() + assert item["label"] == x["label"] + num_iter += 1 + assert num_iter == 16 + + +def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file): + """tutorial for nlp minderdataset.""" + num_readers = 1 + data_set = ds.MindDataset( + NLP_FILE_NAME + "0", None, num_readers, shuffle=False) + old_data_set = ds.MindDataset( + OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False) + assert old_data_set.get_dataset_size() == 16 + num_iter = 0 + for x, item in zip(old_data_set.create_dict_iterator(), data_set.create_dict_iterator()): + assert (item["array_a"] == x["array_a"]).all() + assert (item["array_b"] == x["array_b"]).all() + assert (item["array_c"] == x["array_c"]).all() + assert (item["array_d"] == x["array_d"]).all() + assert item["label"] == x["label"] + num_iter += 1 + assert num_iter == 16 + + def test_cv_minddataset_writer_tutorial(): """tutorial for cv dataset writer.""" paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] for x in paths: os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None - os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None + os.remove("{}.db".format(x)) if os.path.exists( + "{}.db".format(x)) else None writer = FileWriter(CV_FILE_NAME, FILES_NUM) data = get_data(CV_DIR_NAME) cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, @@ -127,8 +221,10 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): num_shards=num_shards, shard_id=partition_id) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- partition : {} ------------------------".format(partition_id)) - logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) + logger.info( + "-------------- partition : {} ------------------------".format(partition_id)) + logger.info( + "-------------- item[label]: {} -----------------------".format(item["label"])) num_iter += 1 return num_iter @@ -147,9 
+243,12 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file): data_set = data_set.repeat(repeat_num) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- get dataset size {} -----------------".format(num_iter)) - logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) - logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) + logger.info( + "-------------- get dataset size {} -----------------".format(num_iter)) + logger.info( + "-------------- item[label]: {} ---------------------".format(item["label"])) + logger.info( + "-------------- item[data]: {} ----------------------".format(item["data"])) num_iter += 1 assert num_iter == 20 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, @@ -163,17 +262,22 @@ def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file): num_readers = 4 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) decode_op = vision.Decode() - data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) + data_set = data_set.map( + input_columns=["data"], operations=decode_op, num_parallel_workers=2) resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) - data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) + data_set = data_set.map(input_columns="data", + operations=resize_op, num_parallel_workers=2) data_set = data_set.batch(2) data_set = data_set.repeat(2) num_iter = 0 labels = [] for item in data_set.create_dict_iterator(): - logger.info("-------------- get dataset size {} -----------------".format(num_iter)) - logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) - logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) + logger.info( + "-------------- get dataset size {} -----------------".format(num_iter)) + logger.info( + "-------------- item[label]: {} ---------------------".format(item["label"])) + logger.info( + "-------------- item[data]: {} ----------------------".format(item["data"])) num_iter += 1 labels.append(item["label"]) assert num_iter == 10 @@ -189,15 +293,20 @@ def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file): num_readers = 4 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) decode_op = vision.Decode() - data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) + data_set = data_set.map( + input_columns=["data"], operations=decode_op, num_parallel_workers=2) resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) - data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) + data_set = data_set.map(input_columns="data", + operations=resize_op, num_parallel_workers=2) data_set = data_set.batch(32, drop_remainder=True) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- get dataset size {} -----------------".format(num_iter)) - logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) - logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) + logger.info( + "-------------- get dataset size {} -----------------".format(num_iter)) + logger.info( + "-------------- item[label]: {} ---------------------".format(item["label"])) + logger.info( + "-------------- item[data]: {} ----------------------".format(item["data"])) 
num_iter += 1 assert num_iter == 0 @@ -206,7 +315,8 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file): """issue 888 test.""" columns_list = ["data", "label"] num_readers = 2 - data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1) + data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, + num_readers, shuffle=False, num_shards=5, shard_id=1) data = data.shuffle(2) data = data.repeat(9) num_iter = 0 @@ -226,9 +336,12 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): data_set = data_set.repeat(repeat_num) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- block reader repeat tow {} -----------------".format(num_iter)) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) num_iter += 1 assert num_iter == 20 @@ -244,10 +357,14 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem data_set = data_set.repeat(repeat_num) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) - logger.info("-------------- item[id]: {} ----------------------------".format(item["id"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- block reader repeat tow {} -----------------".format(num_iter)) + logger.info( + "-------------- item[id]: {} ----------------------------".format(item["id"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) num_iter += 1 assert num_iter == 20 @@ -256,15 +373,21 @@ def test_cv_minddataset_reader_file_list(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] num_readers = 4 - data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers) + data_set = ds.MindDataset([CV_FILE_NAME + str(x) + for x in range(FILES_NUM)], columns_list, num_readers) assert data_set.get_dataset_size() == 10 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} 
-----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter == 10 @@ -277,11 +400,16 @@ def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): assert data_set.get_dataset_size() < 10 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter < 10 @@ -324,11 +452,16 @@ def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): assert data_set.get_dataset_size() == 30 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter == 30 if os.path.exists(CV1_FILE_NAME): @@ -346,7 +479,8 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): for x in range(FILES_NUM)] for x in paths: os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None - os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None + os.remove("{}.db".format(x)) if os.path.exists( + "{}.db".format(x)) else None writer = FileWriter(CV1_FILE_NAME, FILES_NUM) data = get_data(CV_DIR_NAME) cv_schema_json = {"id": {"type": "int32"}, @@ -365,11 +499,16 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): assert data_set.get_dataset_size() < 20 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} 
------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter < 20 for x in paths: @@ -385,11 +524,16 @@ def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): assert data_set.get_dataset_size() == 10 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter == 10 @@ -401,10 +545,14 @@ def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file): assert data_set.get_dataset_size() == 10 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- num_iter: {} ------------------------".format(num_iter)) - logger.info("-------------- item[id]: {} ------------------------".format(item["id"])) - logger.info("-------------- item[rating]: {} --------------------".format(item["rating"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- num_iter: {} ------------------------".format(num_iter)) + logger.info( + "-------------- item[id]: {} ------------------------".format(item["id"])) + logger.info( + "-------------- item[rating]: {} --------------------".format(item["rating"])) logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format( item["input_ids"], item["input_ids"].shape)) logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format( @@ -445,10 +593,13 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_ # define map operations decode_op = vision.Decode() - 
resize_op = vision.Resize((resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) + resize_op = vision.Resize( + (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) - data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4) - data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4) + data_set = data_set.map( + input_columns=["data"], operations=decode_op, num_parallel_workers=4) + data_set = data_set.map( + input_columns=["data"], operations=resize_op, num_parallel_workers=4) data_set = data_set.batch(2) assert data_set.get_dataset_size() == 5 @@ -468,11 +619,16 @@ def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file): assert data_set.get_dataset_size() == 10 num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 assert num_iter == 10 @@ -486,11 +642,16 @@ def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file): data_set = data_set.repeat(repeat_num) num_iter = 0 for item in data_set.create_dict_iterator(): - logger.info("-------------- repeat two test {} ------------------------".format(num_iter)) - logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) - logger.info("-------------- item[data]: {} ----------------------------".format(item["data"])) - logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"])) - logger.info("-------------- item[label]: {} ---------------------------".format(item["label"])) + logger.info( + "-------------- repeat two test {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} ----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} -----------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ---------------------------".format(item["label"])) num_iter += 1 assert num_iter == 20 @@ -599,7 +760,8 @@ def get_mkv_data(dir_name): "id": index} data_list.append(data_json) index += 1 - logger.info('{} images are missing'.format(len(file_list) - len(data_list))) + logger.info('{} images are missing'.format( + len(file_list) - len(data_list))) return data_list @@ -686,6 +848,10 @@ def inputs(vectors, maxlen=50): def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): mindrecord_file_name = "test.mindrecord" + if 
os.path.exists("{}".format(mindrecord_file_name)): + os.remove("{}".format(mindrecord_file_name)) + if os.path.exists("{}.db".format(mindrecord_file_name)): + os.remove("{}.db".format(x)) data = [{"file_name": "001.jpg", "label": 4, "image1": bytes("image1 bytes abc", encoding='UTF-8'), "image2": bytes("image1 bytes def", encoding='UTF-8'), @@ -782,7 +948,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): data_value_to_list = [] for item in data: new_data = {} - new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) + new_data['file_name'] = np.asarray( + list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) @@ -807,7 +974,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 13 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -815,7 +983,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): num_readers = 2 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"], + columns_list=["source_sos_ids", + "source_sos_mask", "target_sos_ids"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -832,7 +1001,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): num_readers = 1 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], + columns_list=[ + "image2", "source_sos_mask", "image3", "target_sos_ids"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -841,7 +1011,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 4 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -849,7 +1020,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): num_readers = 3 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_sos_ids", "image4", "source_sos_ids"], + columns_list=["target_sos_ids", + "image4", "source_sos_ids"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -858,7 +1030,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 3 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -866,7 +1039,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): num_readers = 3 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"], + columns_list=["target_sos_ids", "image5", 
+ "image4", "image3", "source_sos_ids"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -875,7 +1049,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 5 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -883,7 +1058,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): num_readers = 1 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"], + columns_list=["target_eos_mask", "image5", + "image2", "source_sos_mask", "label"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -892,7 +1068,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 5 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -910,7 +1087,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): assert len(item) == 11 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -975,7 +1153,8 @@ def test_write_with_multi_bytes_and_MindDataset(): data_value_to_list = [] for item in data: new_data = {} - new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) + new_data['file_name'] = np.asarray( + list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) @@ -994,7 +1173,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 7 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1011,7 +1191,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 3 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1028,7 +1209,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 2 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1045,7 +1227,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 2 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == 
data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1062,7 +1245,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 3 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1070,7 +1254,8 @@ def test_write_with_multi_bytes_and_MindDataset(): num_readers = 2 data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image4", "image5", "image2", "image3", "file_name"], + columns_list=["image4", "image5", + "image2", "image3", "file_name"], num_parallel_workers=num_readers, shuffle=False) assert data_set.get_dataset_size() == 6 @@ -1079,7 +1264,8 @@ def test_write_with_multi_bytes_and_MindDataset(): assert len(item) == 5 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1177,7 +1363,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 8 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1196,7 +1383,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 6 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1215,7 +1403,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 3 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1234,7 +1423,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 3 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1251,7 +1441,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 1 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 @@ -1271,7 +1462,8 @@ def test_write_with_multi_array_and_MindDataset(): assert len(item) == 8 for field in item: if isinstance(item[field], np.ndarray): - assert (item[field] == data_value_to_list[num_iter][field]).all() + assert (item[field] == + data_value_to_list[num_iter][field]).all() else: assert item[field] == data_value_to_list[num_iter][field] num_iter += 1 diff --git 
a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py index 0662356eecd..5cf778c8892 100644 --- a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py @@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS CIFAR100_DIR = "../data/mindrecord/testCifar100Data" MINDRECORD_FILE = "./cifar100.mindrecord" +@pytest.fixture +def fixture_file(): + """add/remove file""" + def remove_file(x): + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + if os.path.exists("{}_test".format(x)): + os.remove("{}_test".format(x)) + if os.path.exists("{}_test.db".format(x)): + os.remove("{}_test.db".format(x)) -def test_cifar100_to_mindrecord_without_index_fields(): + remove_file(MINDRECORD_FILE) + yield "yield_fixture_data" + remove_file(MINDRECORD_FILE) + +def test_cifar100_to_mindrecord_without_index_fields(fixture_file): """test transform cifar100 dataset to mindrecord without index fields.""" cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) ret = cifar100_transformer.transform() @@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields(): assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - os.remove("{}".format(MINDRECORD_FILE + "_test")) - os.remove("{}.db".format(MINDRECORD_FILE + "_test")) - - -def test_cifar100_to_mindrecord(): +def test_cifar100_to_mindrecord(fixture_file): """test transform cifar100 dataset to mindrecord.""" cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) cifar100_transformer.transform(['fine_label', 'coarse_label']) assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - - os.remove("{}".format(MINDRECORD_FILE + "_test")) - os.remove("{}.db".format(MINDRECORD_FILE + "_test")) def read(): @@ -77,8 +82,7 @@ def read(): assert count == 4 reader.close() - -def test_cifar100_to_mindrecord_illegal_file_name(): +def test_cifar100_to_mindrecord_illegal_file_name(fixture_file): """ test transform cifar100 dataset to mindrecord when file name contains illegal character. @@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name(): cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) cifar100_transformer.transform() - -def test_cifar100_to_mindrecord_filename_start_with_space(): +def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file): """ test transform cifar10 dataset to mindrecord when file name starts with space. @@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space(): cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) cifar100_transformer.transform() - -def test_cifar100_to_mindrecord_filename_contain_space(): +def test_cifar100_to_mindrecord_filename_contain_space(fixture_file): """ test transform cifar10 dataset to mindrecord when file name contains space. 
@@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space(): cifar100_transformer.transform() assert os.path.exists(filename) assert os.path.exists(filename + "_test") - os.remove("{}".format(filename)) - os.remove("{}.db".format(filename)) - os.remove("{}".format(filename + "_test")) - os.remove("{}.db".format(filename + "_test")) - - -def test_cifar100_to_mindrecord_directory(): +def test_cifar100_to_mindrecord_directory(fixture_file): """ test transform cifar10 dataset to mindrecord when destination path is directory. @@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory(): CIFAR100_DIR) cifar100_transformer.transform() - -def test_cifar100_to_mindrecord_filename_equals_cifar100(): +def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): """ test transform cifar10 dataset to mindrecord when destination path equals source path. diff --git a/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py b/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py index aab17c78638..5464cc0e505 100644 --- a/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py @@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS CIFAR10_DIR = "../data/mindrecord/testCifar10Data" MINDRECORD_FILE = "./cifar10.mindrecord" +@pytest.fixture +def fixture_file(): + """add/remove file""" + def remove_file(x): + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + if os.path.exists("{}_test".format(x)): + os.remove("{}_test".format(x)) + if os.path.exists("{}_test.db".format(x)): + os.remove("{}_test.db".format(x)) -def test_cifar10_to_mindrecord_without_index_fields(): + remove_file(MINDRECORD_FILE) + yield "yield_fixture_data" + remove_file(MINDRECORD_FILE) + +@pytest.fixture +def fixture_space_file(): + """add/remove file""" + def remove_file(x): + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + if os.path.exists("{}_test".format(x)): + os.remove("{}_test".format(x)) + if os.path.exists("{}_test.db".format(x)): + os.remove("{}_test.db".format(x)) + + x = "./yes ok" + remove_file(x) + yield "yield_fixture_data" + remove_file(x) + +def test_cifar10_to_mindrecord_without_index_fields(fixture_file): """test transform cifar10 dataset to mindrecord without index fields.""" cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) cifar10_transformer.transform() assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - - os.remove("{}".format(MINDRECORD_FILE + "_test")) - os.remove("{}.db".format(MINDRECORD_FILE + "_test")) -def test_cifar10_to_mindrecord(): + +def test_cifar10_to_mindrecord(fixture_file): """test transform cifar10 dataset to mindrecord.""" cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) cifar10_transformer.transform(['label']) assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - os.remove("{}".format(MINDRECORD_FILE + "_test")) - os.remove("{}.db".format(MINDRECORD_FILE + "_test")) - - -def test_cifar10_to_mindrecord_with_return(): +def test_cifar10_to_mindrecord_with_return(fixture_file): """test transform cifar10 dataset to mindrecord.""" 
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) ret = cifar10_transformer.transform(['label']) @@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return(): assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - - os.remove("{}".format(MINDRECORD_FILE + "_test")) - os.remove("{}.db".format(MINDRECORD_FILE + "_test")) def read(): @@ -90,8 +109,7 @@ def read(): assert count == 4 reader.close() - -def test_cifar10_to_mindrecord_illegal_file_name(): +def test_cifar10_to_mindrecord_illegal_file_name(fixture_file): """ test transform cifar10 dataset to mindrecord when file name contains illegal character. @@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name(): cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) cifar10_transformer.transform() - -def test_cifar10_to_mindrecord_filename_start_with_space(): +def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file): """ test transform cifar10 dataset to mindrecord when file name starts with space. @@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space(): cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) cifar10_transformer.transform() - -def test_cifar10_to_mindrecord_filename_contain_space(): +def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file): """ test transform cifar10 dataset to mindrecord when file name contains space. @@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space(): cifar10_transformer.transform() assert os.path.exists(filename) assert os.path.exists(filename + "_test") - os.remove("{}".format(filename)) - os.remove("{}.db".format(filename)) - os.remove("{}".format(filename + "_test")) - os.remove("{}.db".format(filename + "_test")) - - -def test_cifar10_to_mindrecord_directory(): +def test_cifar10_to_mindrecord_directory(fixture_file): """ test transform cifar10 dataset to mindrecord when destination path is directory. 
diff --git a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py index 385192bbda8..0a8cec54af1 100644 --- a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py @@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images" MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord" PARTITION_NUMBER = 4 +@pytest.fixture +def fixture_file(): + """add/remove file""" + def remove_one_file(x): + if os.path.exists(x): + os.remove(x) + def remove_file(): + x = MINDRECORD_FILE + remove_one_file(x) + x = MINDRECORD_FILE + ".db" + remove_one_file(x) + for i in range(PARTITION_NUMBER): + x = MINDRECORD_FILE + str(i) + remove_one_file(x) + x = MINDRECORD_FILE + str(i) + ".db" + remove_one_file(x) + + remove_file() + yield "yield_fixture_data" + remove_file() def read(filename): """test file reade""" @@ -38,8 +58,7 @@ def read(filename): assert count == 20 reader.close() - -def test_imagenet_to_mindrecord(): +def test_imagenet_to_mindrecord(fixture_file): """test transform imagenet dataset to mindrecord.""" imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, MINDRECORD_FILE, PARTITION_NUMBER) @@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord(): assert os.path.exists(MINDRECORD_FILE + str(i)) assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") read(MINDRECORD_FILE + "0") - for i in range(PARTITION_NUMBER): - os.remove(MINDRECORD_FILE + str(i)) - os.remove(MINDRECORD_FILE + str(i) + ".db") - -def test_imagenet_to_mindrecord_default_partition_number(): +def test_imagenet_to_mindrecord_default_partition_number(fixture_file): """ test transform imagenet dataset to mindrecord when partition number is default. @@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number(): assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + ".db") read(MINDRECORD_FILE) - os.remove("{}".format(MINDRECORD_FILE)) - os.remove("{}.db".format(MINDRECORD_FILE)) - -def test_imagenet_to_mindrecord_partition_number_0(): +def test_imagenet_to_mindrecord_partition_number_0(fixture_file): """ test transform imagenet dataset to mindrecord when partition number is 0. @@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0(): MINDRECORD_FILE, 0) imagenet_transformer.transform() - -def test_imagenet_to_mindrecord_partition_number_none(): +def test_imagenet_to_mindrecord_partition_number_none(fixture_file): """ test transform imagenet dataset to mindrecord when partition number is none. @@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none(): MINDRECORD_FILE, None) imagenet_transformer.transform() - -def test_imagenet_to_mindrecord_illegal_filename(): +def test_imagenet_to_mindrecord_illegal_filename(fixture_file): """ test transform imagenet dataset to mindrecord when file name contains illegal character. 
diff --git a/tests/ut/python/mindrecord/test_mindrecord_exception.py b/tests/ut/python/mindrecord/test_mindrecord_exception.py
index 2eef72cebed..bb9455a6b77 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_exception.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_exception.py
@@ -26,6 +26,34 @@
 CV_FILE_NAME = "./imagenet.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"
 FILES_NUM = 4
+def remove_one_file(x):
+    if os.path.exists(x):
+        os.remove(x)
+
+def remove_file(file_name):
+    x = file_name
+    remove_one_file(x)
+    x = file_name + ".db"
+    remove_one_file(x)
+    for i in range(FILES_NUM):
+        x = file_name + str(i)
+        remove_one_file(x)
+        x = file_name + str(i) + ".db"
+        remove_one_file(x)
+
+@pytest.fixture
+def fixture_cv_file():
+    """add/remove file"""
+    remove_file(CV_FILE_NAME)
+    yield "yield_fixture_data"
+    remove_file(CV_FILE_NAME)
+
+@pytest.fixture
+def fixture_nlp_file():
+    """add/remove file"""
+    remove_file(NLP_FILE_NAME)
+    yield "yield_fixture_data"
+    remove_file(NLP_FILE_NAME)
 
 def test_cv_file_writer_shard_num_none():
     """test cv file writer when shard num is None."""
@@ -83,8 +111,7 @@ def test_lack_partition_and_db():
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
 
-
-def test_lack_db():
+def test_lack_db(fixture_cv_file):
     """test file reader when db file does not exist."""
     create_cv_mindrecord(1)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -94,10 +121,8 @@ def test_lack_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove(CV_FILE_NAME)
 
-
-def test_lack_some_partition_and_db():
+def test_lack_some_partition_and_db(fixture_cv_file):
     """test file reader when some partition and db do not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -110,16 +135,8 @@ def test_lack_some_partition_and_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-
-def test_lack_some_partition_first():
+def test_lack_some_partition_first(fixture_cv_file):
     """test file reader when first partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -131,14 +148,8 @@ def test_lack_some_partition_first():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-
-def test_lack_some_partition_middle():
+def test_lack_some_partition_middle(fixture_cv_file):
     """test file reader when some partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -150,14 +161,8 @@ def test_lack_some_partition_middle():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-
-def test_lack_some_partition_last():
+def test_lack_some_partition_last(fixture_cv_file):
     """test file reader when last partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -169,14 +174,8 @@ def test_lack_some_partition_last():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-
-def test_mindpage_lack_some_partition():
+def test_mindpage_lack_some_partition(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-
-def test_lack_some_db():
+def test_lack_some_db(fixture_cv_file):
     """test file reader when some db does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@@ -206,11 +199,6 @@ def test_lack_some_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
 
 def test_invalid_mindrecord():
@@ -225,8 +213,7 @@ def test_invalid_mindrecord():
            in str(err.value)
     os.remove(CV_FILE_NAME)
 
-
-def test_invalid_db():
+def test_invalid_db(fixture_cv_file):
     """test file reader when the content of db is illegal."""
     create_cv_mindrecord(1)
     os.remove("imagenet.mindrecord.db")
@@ -237,11 +224,8 @@ def test_invalid_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove("imagenet.mindrecord")
-    os.remove("imagenet.mindrecord.db")
 
-
-def test_overwrite_invalid_mindrecord():
+def test_overwrite_invalid_mindrecord(fixture_cv_file):
     """test file writer when overwrite invalid mindreocrd file."""
     with open(CV_FILE_NAME, 'w') as f:
         f.write('just for test')
@@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove(CV_FILE_NAME)
 
-
-def test_overwrite_invalid_db():
+def test_overwrite_invalid_db(fixture_cv_file):
     """test file writer when overwrite invalid db file."""
     with open('imagenet.mindrecord.db', 'w') as f:
         f.write('just for test')
@@ -261,11 +243,8 @@ def test_overwrite_invalid_db():
         create_cv_mindrecord(1)
     assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \
            'error_msg: Failed to generate index.' in str(err.value)
-    os.remove("imagenet.mindrecord")
-    os.remove("imagenet.mindrecord.db")
 
-
-def test_read_after_close():
+def test_read_after_close(fixture_cv_file):
     """test file reader when close read."""
     create_cv_mindrecord(1)
     reader = FileReader(CV_FILE_NAME)
@@ -275,11 +254,8 @@ def test_read_after_close():
         count = count + 1
         logger.info("#item{}: {}".format(index, x))
     assert count == 0
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
 
-
-def test_file_read_after_read():
+def test_file_read_after_read(fixture_cv_file):
     """test file reader when finish read."""
     create_cv_mindrecord(1)
     reader = FileReader(CV_FILE_NAME)
@@ -295,8 +271,6 @@ def test_file_read_after_read():
         cnt = cnt + 1
         logger.info("#item{}: {}".format(index, x))
     assert cnt == 0
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
 
 
 def test_cv_file_writer_shard_num_greater_than_1000():
@@ -312,8 +286,7 @@ def test_add_index_without_add_schema():
         fw.add_index(["label"])
     assert 'Failed to get meta info' in str(err.value)
 
-
-def test_mindpage_pageno_pagesize_not_int():
+def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     reader = MindPage(CV_FILE_NAME + "0")
@@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int():
     with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
         reader.read_at_page_by_id(99999, 0, 1)
 
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
 
-
-def test_mindpage_filename_not_exist():
+def test_mindpage_filename_not_exist(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     reader = MindPage(CV_FILE_NAME + "0")
@@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist():
 
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
diff --git a/tests/ut/python/mindrecord/test_mnist_to_mr.py b/tests/ut/python/mindrecord/test_mnist_to_mr.py
index 55376f3dcb4..d24518ceba0 100644
--- a/tests/ut/python/mindrecord/test_mnist_to_mr.py
+++ b/tests/ut/python/mindrecord/test_mnist_to_mr.py
@@ -14,6 +14,7 @@
 """test mnist to mindrecord tool"""
 import cv2
 import gzip
+import pytest
 import numpy as np
 import os
 
@@ -27,6 +28,34 @@
 PARTITION_NUM = 4
 IMAGE_SIZE = 28
 NUM_CHANNELS = 1
+@pytest.fixture
+def fixture_file():
+    """add/remove file"""
+    def remove_one_file(x):
+        if os.path.exists(x):
+            os.remove(x)
+    def remove_file():
+        x = "mnist_train.mindrecord"
+        remove_one_file(x)
+        x = "mnist_train.mindrecord.db"
+        remove_one_file(x)
+        x = "mnist_test.mindrecord"
+        remove_one_file(x)
+        x = "mnist_test.mindrecord.db"
+        remove_one_file(x)
+        for i in range(PARTITION_NUM):
+            x = "mnist_train.mindrecord" + str(i)
+            remove_one_file(x)
+            x = "mnist_train.mindrecord" + str(i) + ".db"
+            remove_one_file(x)
+            x = "mnist_test.mindrecord" + str(i)
+            remove_one_file(x)
+            x = "mnist_test.mindrecord" + str(i) + ".db"
+            remove_one_file(x)
+
+    remove_file()
+    yield "yield_fixture_data"
+    remove_file()
 
 def read(train_name, test_name):
     """test file reader"""
@@ -51,7 +80,7 @@ def read(train_name, test_name):
     reader.close()
 
 
-def test_mnist_to_mindrecord():
+def test_mnist_to_mindrecord(fixture_file):
     """test transform mnist dataset to mindrecord."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
     mnist_transformer.transform()
@@ -60,13 +89,7 @@ def test_mnist_to_mindrecord():
 
     read("mnist_train.mindrecord", "mnist_test.mindrecord")
 
-    os.remove("{}".format("mnist_train.mindrecord"))
-    os.remove("{}.db".format("mnist_train.mindrecord"))
-    os.remove("{}".format("mnist_test.mindrecord"))
-    os.remove("{}.db".format("mnist_test.mindrecord"))
-
-
-def test_mnist_to_mindrecord_compare_data():
+def test_mnist_to_mindrecord_compare_data(fixture_file):
     """test transform mnist dataset to mindrecord and compare data."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
     mnist_transformer.transform()
@@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data():
             assert np.array(x['label']) == label
     reader.close()
 
-    os.remove("{}".format("mnist_train.mindrecord"))
-    os.remove("{}.db".format("mnist_train.mindrecord"))
-    os.remove("{}".format("mnist_test.mindrecord"))
-    os.remove("{}.db".format("mnist_test.mindrecord"))
-
-
-def test_mnist_to_mindrecord_multi_partition():
+def test_mnist_to_mindrecord_multi_partition(fixture_file):
     """test transform mnist dataset to multiple mindrecord files."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
     mnist_transformer.transform()
 
     read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
 
-    for i in range(PARTITION_NUM):
-        os.remove("{}".format("mnist_train.mindrecord" + str(i)))
-        os.remove("{}.db".format("mnist_train.mindrecord" + str(i)))
-        os.remove("{}".format("mnist_test.mindrecord" + str(i)))
-        os.remove("{}.db".format("mnist_test.mindrecord" + str(i)))
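The test changes above all rely on the standard pytest yield-fixture pattern: code before the yield runs as setup, code after it runs as teardown, so stale MindRecord output is cleaned up before and after every test, even when the test body raises. A minimal, self-contained sketch of that pattern (the fixture name, test name, and file names below are illustrative only and are not part of this patch):

import os

import pytest


@pytest.fixture
def cleanup_mindrecord_files():
    """Remove leftover output before and after any test that requests this fixture."""
    def _remove(path):
        if os.path.exists(path):
            os.remove(path)

    # setup: clear files left behind by earlier, possibly failed, runs
    for suffix in ("", ".db"):
        _remove("example.mindrecord" + suffix)
    yield
    # teardown: runs even if the test body raised an exception
    for suffix in ("", ".db"):
        _remove("example.mindrecord" + suffix)


def test_writer_output_is_cleaned_up(cleanup_mindrecord_files):
    # The fixture has already removed stale files; create a dummy output file
    # and rely on the fixture's teardown to delete it after the assertion.
    with open("example.mindrecord", "w") as f:
        f.write("dummy")
    assert os.path.exists("example.mindrecord")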