forked from mindspore-Ecosystem/mindspore
!1317 [MD]add compress for nlp data in mindrecord
Merge pull request !1317 from liyong126/mindrecord_compress
Commit: decf12cd0b
@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>();

std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
// check whether schema exists, if so use the first one
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];
std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();

bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything
std::map<std::string, int32_t> colname_to_ind;
for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) {
std::string colname = it.key(); // key of the json, column name
mindrecord::json it_value = it.value(); // value, which contains type info and may contain shape
for (uint32_t i = 0; i < col_names.size(); i++) {
std::string colname = col_names[i];
ColDescriptor col_desc;

TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown
std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"];
std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]];
DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if (it_value["type"] == "bytes") { // rank = 1

if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1);
} else if (it_value.find("shape") != it_value.end()) {
std::vector<dsize_t> vec(it_value["shape"].size()); // temporary vector to hold shape
(void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
} else if (col_shapes[i].size() > 0) {
std::vector<dsize_t> vec(col_shapes[i].size()); // temporary vector to hold shape
(void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin());
t_shape = TensorShape(vec);
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape);
} else { // unknown shape
@@ -162,33 +163,10 @@ Status MindRecordOp::Init() {
num_rows_ = shard_reader_->GetNumRows();
// Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
RETURN_IF_NOT_OK(SetColumnsBlob());

return Status::OK();
}

Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->GetBlobFields().second;

// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
for (auto &blob_field : columns_blob_) {
for (auto &column : columns_to_load_) {
if (column.compare(blob_field) == 0) {
columns_blob_exact.push_back(blob_field);
break;
}
}
}

columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
int32_t iBlob = 0;
for (auto &blob_exact : columns_blob_exact) {
columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++;
}
return Status::OK();
}

// Destructor
MindRecordOp::~MindRecordOp() {}
@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
}
}

template <typename T>
Status MindRecordOp::LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const {
TensorShape new_shape = TensorShape::CreateUnknownRankShape();
const unsigned char *data = nullptr;
std::unique_ptr<T[]> array_data;
std::string string_data;

const ColDescriptor &cur_column = data_schema_->column(i_col);
std::string column_name = columns_to_load_[i_col];
DataType type = cur_column.type();

// load blob column
if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) {
int32_t pos = columns_blob_.size() == 1 ? -1 : columns_blob_index_[i_col];
RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column));
} else {
switch (type.value()) {
case DataType::DE_UINT8: {
// For strings (Assume DE_UINT8 is reserved for strings)
RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json));
data = reinterpret_cast<const unsigned char *>(common::SafeCStr(string_data));
break;
}
case DataType::DE_FLOAT32: {
// For both float scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
case DataType::DE_FLOAT64: {
// For both double scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
default: {
// For both integers scalars and arrays
RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
}
}
// Create Tensor with given details
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data));

return Status::OK();
}

Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data,
const std::vector<uint8_t> &columns_blob, const int32_t pos,
const ColDescriptor &column) {
const auto kColumnSize = column.type().SizeInBytes();
if (kColumnSize == 0) {
RETURN_STATUS_UNEXPECTED("column size is null");
}
if (pos == -1) {
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(
column.MaterializeTensorShape(static_cast<int32_t>(columns_blob.size() / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_blob.size() / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[0]));
return Status::OK();
}
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + columns_blob[pos + n];
}
return result;
};
uint64_t iStart = 0;
for (int32_t i = 0; i < pos; i++) {
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len + num_bytes;
}
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len;
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_bytes / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_bytes / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[iStart]));
return Status::OK();
}

template <typename T>
Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double));

*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}

int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, element, use_double));

(*array_data)[idx++] = value;
}
}

return Status::OK();
}

template <typename T>
Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) {
if (data.is_number()) {
*value = data;
} else if (data.is_string()) {
try {
if (use_double) {
*value = data.get<double>();
} else {
*value = data.get<float>();
}
} catch (mindrecord::json::exception &e) {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}

return Status::OK();
}

template <typename T>
Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name]));

*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}

int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, element));

(*array_data)[idx++] = value;
}
}

return Status::OK();
}

template <typename T>
Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) {
int64_t temp_value = 0;
bool less_than_zero = false;

if (data.is_number_integer()) {
const mindrecord::json json_zero = 0;
if (data < json_zero) less_than_zero = true;
temp_value = data;
} else if (data.is_string()) {
std::string string_value = data;

if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
}

if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed. Out of range");
}
*value = static_cast<T>(temp_value);

return Status::OK();
}

Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json) {
*string_data = columns_json[column_name];
std::vector<dsize_t> shape_details = {static_cast<dsize_t>(string_data->size())};
*new_shape = TensorShape(shape_details);

return Status::OK();
}

Status MindRecordOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) {
if (io_block->eoe() == true) {
if (io_block->eoe()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
continue;
}
if (io_block->eof() == true) {
if (io_block->eof()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
if (tupled_buffer.empty()) break;
}
for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columnsBlob = std::get<0>(tupled_row);
std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
mindrecord::json columns_json = std::get<1>(tupled_row);
TensorRow tensor_row;
for (uint32_t j = 0; j < columns_to_load_.size(); ++j) {
std::shared_ptr<Tensor> tensor;

const ColDescriptor &cur_column = data_schema_->column(j);
DataType type = cur_column.type();
RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json));

tensor_row.push_back(std::move(tensor));
}

RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json));
tensor_table->push_back(std::move(tensor_row));
}
}
@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
return Status::OK();
}

Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const {
switch (type.value()) {
case DataType::DE_BOOL: {
return LoadFeature<bool>(tensor, i_col, columns_blob, columns_json);
Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) {
for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) {
auto column_name = columns_to_load_[i_col];

// Initialize column parameters
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
uint64_t column_data_type_size = 1;
std::vector<int64_t> column_shape;

// Get column data
auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
&column_shape);
if (has_column == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader.");
}
case DataType::DE_INT8: {
return LoadFeature<int8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT8: {
return LoadFeature<uint8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT16: {
return LoadFeature<int16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT16: {
return LoadFeature<uint16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT32: {
return LoadFeature<int32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT32: {
return LoadFeature<uint32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT64: {
return LoadFeature<int64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT64: {
return LoadFeature<uint64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT32: {
return LoadFeature<float>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT64: {
return LoadFeature<double>(tensor, i_col, columns_blob, columns_json);
}
default: {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"mindrecord column list type does not match any known types");

std::shared_ptr<Tensor> tensor;
const ColDescriptor &column = data_schema_->column(i_col);
DataType type = column.type();

// Set shape
auto num_elements = n_bytes / column_data_type_size;
if (column.hasShape()) {
auto new_shape = TensorShape(column.shape());
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_elements)};
auto new_shape = TensorShape(shapeDetails);
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
}
tensor_row->push_back(std::move(tensor));
}
return Status::OK();
}

Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
@@ -23,6 +23,7 @@
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h"
@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status Init();

Status SetColumnsBlob();

// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);

// Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in
// @param i_col - the id of column to parse
// @param tensor_row - the tensor row to put the parsed data in
// @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader
template <typename T>
Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const;

Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const;

static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob,
const int32_t pos, const ColDescriptor &column);

// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double);

// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template <typename T>
static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column);

static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json);

// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status GetFloat(T *value, const mindrecord::json &data, bool use_double);

// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template <typename T>
static Status GetInt(T *value, const mindrecord::json &data);
Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json);

Status FetchBlockBuffer(const int32_t &buffer_id);
@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next",
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}
@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };

const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};

enum ShardType {
kNLP = 0,
kCV = 1,
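The new kVersion/kSupportedVersion constants gate the blob compression: files written before version 3.0 must be read without it (see the ShardReader::Init change further down). A minimal sketch of how a caller might use these constants; the file_version variable and helper names are hypothetical, not part of the commit:

#include <algorithm>
#include <string>
#include <vector>

// Assumed constants mirroring the diff above.
const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};

// Returns true if the shard file's version string is one the reader understands.
bool IsSupportedVersion(const std::string &file_version) {
  return std::find(kSupportedVersion.begin(), kSupportedVersion.end(), file_version) != kSupportedVersion.end();
}

// Returns true if the file was written with integer-blob compression (version >= 3.0),
// using the same lexicographic comparison ShardReader::Init performs on the meta data.
bool HasCompressedBlob(const std::string &file_version) {
  return !(file_version < std::string(kVersion));
}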
@@ -0,0 +1,163 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"

namespace mindspore {
namespace mindrecord {
const uint64_t kUnsignedOne = 1;
const uint64_t kBitsOfByte = 8;
const uint64_t kDataTypeBits = 2;
const uint64_t kNumDataOfByte = 4;
const uint64_t kBytesOfColumnLen = 4;
const uint64_t kDataTypeBitMask = 3;
const uint64_t kDataTypes = 6;

enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };

enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };

enum ColumnDataType {
ColumnBytes = 0,
ColumnString = 1,
ColumnInt32 = 2,
ColumnInt64 = 3,
ColumnFloat32 = 4,
ColumnFloat64 = 5,
ColumnNoDataType = 6
};

// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};

const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
"int64", "float32", "float64"};

const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
{"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32},
{"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};

class ShardColumn {
public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);

~ShardColumn() = default;

/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);

/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);

/// \brief check if blob is compressed
bool CheckCompressBlob() const { return has_compress_blob_; }

uint64_t GetNumBlobColumn() const { return num_blob_column_; }

std::vector<std::string> GetColumnName() { return column_name_; }

std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }

std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }

/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes);

private:
/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);

/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);

/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);

/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);

/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);

/// \brief compress integer column
static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type);

/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);

/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type);

/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);

/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);

/// \brief convert little-endian bytes to the narrowest integer type
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer type
/// \param dst_i_type (output), destination integer type
/// \return integer
static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);

private:
std::vector<std::string> column_name_; // column name list
std::vector<ColumnDataType> column_data_type_; // column data type list
std::vector<std::vector<int64_t>> column_shape_; // column shape list
std::unordered_map<string, uint64_t> column_name_id_; // column name id map
std::vector<std::string> blob_column_; // blob column list
std::unordered_map<std::string, uint64_t> blob_column_id_; // blob column name id map
bool has_compress_blob_; // if has compress blob
uint64_t num_blob_column_; // number of blob columns
};
} // namespace mindrecord
} // namespace mindspore

#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
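The constants above describe the on-disk encoding used for compressed integer columns: a 4-byte big-endian element count (kBytesOfColumnLen), then a width bitmap with kDataTypeBits (2) bits per element packed kNumDataOfByte (4) elements per byte, then each value stored little-endian in the narrowest of 1/2/4/8 bytes. A minimal, self-contained sketch of that layout (not the ShardColumn class itself); to keep the width selection simple it assumes non-negative inputs, whereas the real CompressInt also narrows negative values via BytesLittleToMinIntType:

#include <cstdint>
#include <vector>

// Encode values as: [4-byte big-endian count][2-bit width codes, 4 per byte][packed little-endian values].
// Width code 0/1/2/3 selects 1/2/4/8 bytes, matching IntegerType in shard_column.h.
std::vector<uint8_t> EncodeIntColumnSketch(const std::vector<int64_t> &values) {
  std::vector<uint8_t> out(4 + (values.size() + 3) / 4, 0);
  uint32_t n = static_cast<uint32_t>(values.size());
  for (int b = 0; b < 4; ++b) out[b] = static_cast<uint8_t>(n >> (8 * (3 - b)));  // count, big-endian
  for (size_t i = 0; i < values.size(); ++i) {
    uint64_t v = static_cast<uint64_t>(values[i]);
    int code = v <= 0x7F ? 0 : v <= 0x7FFF ? 1 : v <= 0x7FFFFFFF ? 2 : 3;       // narrowest signed width
    out[4 + i / 4] |= static_cast<uint8_t>(code << (2 * (3 - i % 4)));           // bitmap, high bits first
    for (int b = 0; b < (1 << code); ++b) out.push_back(static_cast<uint8_t>(v >> (8 * b)));  // little-endian value
  }
  return out;
}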
@@ -118,8 +118,6 @@ class ShardHeader {
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }

const string GetVersion() { return version_; }

std::vector<std::string> SerializeHeader();

MSRStatus PagesToFile(const std::string dump_file_name);
@@ -175,7 +173,6 @@ class ShardHeader {
uint32_t shard_count_;
uint64_t header_size_;
uint64_t page_size_;
string version_ = "2.0";

std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_;
@@ -43,6 +43,7 @@
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"
@@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata
std::shared_ptr<ShardHeader> GetShardHeader() const;

/// \brief aim to get columns context
/// \return the columns
std::shared_ptr<ShardColumn> get_shard_column() const;

/// \brief get the number of shards
/// \return # of shards
int GetShardCount() const;
@@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy();
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();

/// \brief get blob field list
/// \return blob field list
@@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field);

/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);

/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);

protected:
uint64_t header_size_; // header size
uint64_t page_size_; // page size
int shard_count_; // number of shards
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard column

std::vector<sqlite3 *> database_paths_; // sqlite handle list
std::vector<string> file_paths_; // file paths
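Both the removed ExtractBlobFieldBySelectColumns and the new UnCompressBlob/GetColumnAddressInBlock walk the same raw blob layout: the blob fields of one row are concatenated, each prefixed by its length as an 8-byte big-endian integer (a row with a single blob field carries no prefix). A minimal sketch of that walk, independent of the ShardReader class and written only to illustrate the layout:

#include <cstdint>
#include <vector>

// Split a multi-field blob of the form [8-byte big-endian length][payload]... into per-field payloads.
std::vector<std::vector<uint8_t>> SplitBlobFieldsSketch(const std::vector<uint8_t> &blob) {
  const uint64_t kLen = 8;
  std::vector<std::vector<uint8_t>> fields;
  uint64_t offset = 0;
  while (offset + kLen <= blob.size()) {
    uint64_t n = 0;
    for (uint64_t i = 0; i < kLen; ++i) n = (n << 8) + blob[offset + i];  // big-endian length prefix
    offset += kLen;
    fields.emplace_back(blob.begin() + offset, blob.begin() + offset + n);
    offset += n;
  }
  return fields;
}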
@@ -36,6 +36,7 @@
#include <utility>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h"
@@ -242,7 +243,8 @@ class ShardWriter {
std::vector<std::string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard columns

std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info
@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
}
num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) {
@@ -226,6 +232,8 @@ void ShardReader::Close() {
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }

std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }

int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }

int ShardReader::GetNumRows() const { return num_rows_; }
@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}

std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;

auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};

// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}

return exactly_blob_fields_bytes;
}

TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}

// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = GetBlobFields();

std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}

if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}

// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));

return std::make_pair(SUCCESS, std::move(batch));
}
@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return std::move(ret.second);
}

std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() {
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob(
const std::vector<uint8_t> &raw_blob_data) {
auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
auto blob_fields = GetBlobFields().second;
std::vector<std::vector<uint8_t>> blob_data;
for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << ".";
return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())};
}
if (data == nullptr) {
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
blob_data.push_back(column);
}
return {SUCCESS, blob_data};
}

std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() {
auto res = GetNext();
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData;
std::transform(res.begin(), res.end(), std::back_inserter(jsonData),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data;
std::transform(res.begin(), res.end(), std::back_inserter(data),
[this](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
auto ret = UnCompressBlob(std::get<0>(item));
return std::make_tuple(ret.second, std::move(obj));
});
return jsonData;
return data;
}

void ShardReader::Reset() {
@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG(ERROR) << "Open file failed";
return FAILED;
}
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_ = header_data;
shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_);
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}

// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
}
}

// Add 4-byte dummy blob data if there are no blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
@ -0,0 +1,473 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindrecord/include/shard_column.h"
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "mindrecord/include/common/shard_utils.h"
|
||||
#include "mindrecord/include/shard_error.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
|
||||
auto first_schema = shard_header->GetSchemas()[0];
|
||||
auto schema = first_schema->GetSchema()["schema"];
|
||||
|
||||
bool has_integer_array = false;
|
||||
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
|
||||
const std::string &column_name = it.key();
|
||||
column_name_.push_back(column_name);
|
||||
|
||||
json it_value = it.value();
|
||||
|
||||
std::string str_type = it_value["type"];
|
||||
column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
|
||||
if (it_value.find("shape") != it_value.end()) {
|
||||
std::vector<int64_t> vec(it_value["shape"].size());
|
||||
std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
|
||||
column_shape_.push_back(vec);
|
||||
if (str_type == "int32" || str_type == "int64") {
|
||||
has_integer_array = true;
|
||||
}
|
||||
} else {
|
||||
std::vector<int64_t> vec = {};
|
||||
column_shape_.push_back(vec);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < column_name_.size(); i++) {
|
||||
column_name_id_[column_name_[i]] = i;
|
||||
}
|
||||
|
||||
auto blob_fields = first_schema->GetBlobFields();
|
||||
|
||||
for (const auto &field : blob_fields) {
|
||||
blob_column_.push_back(field);
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < blob_column_.size(); i++) {
|
||||
blob_column_id_[blob_column_[i]] = i;
|
||||
}
|
||||
|
||||
has_compress_blob_ = (compress_integer && has_integer_array);
|
||||
num_blob_column_ = blob_column_.size();
|
||||
}
|
||||
|
||||
MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
|
||||
const json &columns_json, const unsigned char **data,
|
||||
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
|
||||
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
|
||||
std::vector<int64_t> *column_shape) {
|
||||
// Skip if column not found
|
||||
auto column_category = CheckColumnName(column_name);
|
||||
if (column_category == ColumnNotFound) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Get data type and size
|
||||
auto column_id = column_name_id_[column_name];
|
||||
*column_data_type = column_data_type_[column_id];
|
||||
*column_data_type_size = ColumnDataTypeSize[*column_data_type];
|
||||
*column_shape = column_shape_[column_id];
|
||||
|
||||
// Retrieve value from json
|
||||
if (column_category == ColumnInRaw) {
|
||||
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
|
||||
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
|
||||
return FAILED;
|
||||
}
|
||||
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// Retrieve value from blob
|
||||
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
|
||||
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
|
||||
return FAILED;
|
||||
}
|
||||
if (*data == nullptr) {
|
||||
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
|
||||
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
|
||||
auto column_id = column_name_id_[column_name];
|
||||
auto column_data_type = column_data_type_[column_id];
|
||||
|
||||
// Initialize num bytes
|
||||
*n_bytes = ColumnDataTypeSize[column_data_type];
|
||||
auto json_column_value = columns_json[column_name];
|
||||
switch (column_data_type) {
|
||||
case ColumnFloat32: {
|
||||
return GetFloat<float>(data_ptr, json_column_value, false);
|
||||
}
|
||||
case ColumnFloat64: {
|
||||
return GetFloat<double>(data_ptr, json_column_value, true);
|
||||
}
|
||||
case ColumnInt32: {
|
||||
return GetInt<int32_t>(data_ptr, json_column_value);
|
||||
}
|
||||
case ColumnInt64: {
|
||||
return GetInt<int64_t>(data_ptr, json_column_value);
|
||||
}
|
||||
default: {
|
||||
// Convert string to c_str
|
||||
std::string tmp_string = json_column_value;
|
||||
*n_bytes = tmp_string.size();
|
||||
auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
|
||||
*data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
|
||||
for (uint32_t i = 0; i < *n_bytes; i++) {
|
||||
(*data_ptr)[i] = *(data + i);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
|
||||
bool use_double) {
|
||||
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
|
||||
if (!json_column_value.is_string() && !json_column_value.is_number()) {
|
||||
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
|
||||
return FAILED;
|
||||
}
|
||||
if (json_column_value.is_number()) {
|
||||
array_data[0] = json_column_value;
|
||||
} else {
|
||||
// Convert string to float
|
||||
try {
|
||||
if (use_double) {
|
||||
array_data[0] = json_column_value.get<double>();
|
||||
} else {
|
||||
array_data[0] = json_column_value.get<float>();
|
||||
}
|
||||
} catch (json::exception &e) {
|
||||
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
|
||||
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
|
||||
for (uint32_t i = 0; i < sizeof(T); i++) {
|
||||
(*data_ptr)[i] = *(data + i);
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
|
||||
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
|
||||
int64_t temp_value;
|
||||
bool less_than_zero = false;
|
||||
|
||||
if (json_column_value.is_number_integer()) {
|
||||
const json json_zero = 0;
|
||||
if (json_column_value < json_zero) less_than_zero = true;
|
||||
temp_value = json_column_value;
|
||||
} else if (json_column_value.is_string()) {
|
||||
std::string string_value = json_column_value;
|
||||
|
||||
if (!string_value.empty() && string_value[0] == '-') {
|
||||
try {
|
||||
temp_value = std::stoll(string_value);
|
||||
less_than_zero = true;
|
||||
} catch (std::invalid_argument &e) {
|
||||
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
|
||||
return FAILED;
|
||||
} catch (std::out_of_range &e) {
|
||||
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
|
||||
return FAILED;
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
temp_value = static_cast<int64_t>(std::stoull(string_value));
|
||||
} catch (std::invalid_argument &e) {
|
||||
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
|
||||
return FAILED;
|
||||
} catch (std::out_of_range &e) {
|
||||
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Conversion to int failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
|
||||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
|
||||
MS_LOG(ERROR) << "Conversion to int failed. Out of range";
|
||||
return FAILED;
|
||||
}
|
||||
array_data[0] = static_cast<T>(temp_value);
|
||||
|
||||
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
|
||||
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
|
||||
for (uint32_t i = 0; i < sizeof(T); i++) {
|
||||
(*data_ptr)[i] = *(data + i);
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
|
||||
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
|
||||
uint64_t *n_bytes) {
|
||||
uint64_t offset_address = 0;
|
||||
auto column_id = column_name_id_[column_name];
|
||||
if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto column_data_type = column_data_type_[column_id];
|
||||
if (has_compress_blob_ && column_data_type == ColumnInt32) {
|
||||
if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
} else if (has_compress_blob_ && column_data_type == ColumnInt64) {
|
||||
if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
} else {
|
||||
*data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
|
||||
auto it_column = column_name_id_.find(column_name);
|
||||
if (it_column == column_name_id_.end()) {
|
||||
return ColumnNotFound;
|
||||
}
|
||||
auto it_blob = blob_column_id_.find(column_name);
|
||||
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
|
||||
// Skip if no compress columns
|
||||
if (!CheckCompressBlob()) return blob;
|
||||
|
||||
std::vector<uint8_t> dst_blob;
|
||||
uint64_t i_src = 0;
|
||||
for (int64_t i = 0; i < num_blob_column_; i++) {
|
||||
// Get column data type
|
||||
auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
|
||||
auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;
|
||||
|
||||
// Compress and return is blob has 1 column only
|
||||
if (num_blob_column_ == 1) {
|
||||
return CompressInt(blob, int_type);
|
||||
}
|
||||
|
||||
// Just copy and continue if column dat type is not int32/int64
|
||||
uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
|
||||
if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
|
||||
dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
|
||||
i_src += kInt64Len + num_bytes;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get column slice in source blob
|
||||
std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
|
||||
// Compress column
|
||||
auto dst_blob_slice = CompressInt(blob_slice, int_type);
|
||||
// Get new column size
|
||||
auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
|
||||
// Append new colmn size
|
||||
dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
|
||||
// Append new colmn data
|
||||
dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
|
||||
i_src += kInt64Len + num_bytes;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
|
||||
return dst_blob;
|
||||
}
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
  uint64_t i_size = kUnsignedOne << int_type;
  // Get number of elements
  uint64_t src_n_int = src_bytes.size() / i_size;
  // Calculate bitmap size (bytes)
  uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;

  // Initialize destination blob with more space than needed; it will be resized
  vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);

  // Write number of elements to destination blob
  vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
  for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
    dst_bytes[n] = size_by_bytes[n];
  }

  // Write compressed ints
  uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
  for (uint64_t i = 0; i < src_n_int; i++) {
    // Initialize destination data type
    IntegerType dst_int_type = kInt8Type;
    // Shift to next int position
    uint64_t pos = i * (kUnsignedOne << int_type);
    // Narrow down this int
    int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

    // Write this int to destination blob
    uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
    auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
    for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) {
      dst_bytes[i_dst++] = temp_bytes[j];
    }

    // Update data type in bitmap
    dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
      (dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
  }
  // Resize destination blob
  dst_bytes.resize(i_dst);
  MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
  return dst_bytes;
}
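
CompressInt emits a small self-describing layout: a 4-byte big-endian element count, a bitmap with two bits per element giving the stored width, then the values themselves, each narrowed to the smallest signed type that holds it and written little-endian. The Python below is a hedged sketch of that layout for illustration only; the constant values (kBytesOfColumnLen = 4, kNumDataOfByte = 4, kDataTypeBits = 2, width codes 0-3 for 1/2/4/8 bytes) are inferred from the surrounding code, not quoted from the headers.

def compress_int(values):
    """Sketch of the assumed layout: [count: 4B big-endian][2-bit width bitmap][narrowed little-endian ints]."""
    values = [int(v) for v in values]
    header = len(values).to_bytes(4, 'big')               # element count (kBytesOfColumnLen = 4, assumed)
    bitmap = bytearray((len(values) + 3) // 4)            # four 2-bit entries per byte (kNumDataOfByte = 4, assumed)
    body = bytearray()
    for i, v in enumerate(values):
        # Pick the narrowest signed width that still holds v; width codes 0..3 -> 1/2/4/8 bytes (assumed)
        for code, n_bytes in enumerate((1, 2, 4, 8)):
            if -(1 << (8 * n_bytes - 1)) <= v < (1 << (8 * n_bytes - 1)):
                break
        bitmap[i // 4] |= code << (2 * (3 - i % 4))       # kDataTypeBits = 2 (assumed)
        body += v.to_bytes(n_bytes, 'little', signed=True)
    return header + bytes(bitmap) + bytes(body)

A column of small labels therefore shrinks from 8 bytes per value to roughly 1 byte per value plus the fixed count and bitmap, which is where the size reduction reported by the MS_LOG(DEBUG) line above comes from.
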
MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
                                               uint64_t *num_bytes, uint64_t *shift_idx) {
  if (num_blob_column_ == 1) {
    *num_bytes = columns_blob.size();
    *shift_idx = 0;
    return SUCCESS;
  }
  auto blob_id = blob_column_id_[column_name_[column_id]];

  for (int32_t i = 0; i < blob_id; i++) {
    *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
  }
  *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);

  (*shift_idx) += kInt64Len;

  return SUCCESS;
}

template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
                                     const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
                                     uint64_t shift_idx) {
  auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
  *num_bytes = sizeof(T) * num_elements;

  // Parse integer array
  uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
  auto array_data = std::make_unique<T[]>(num_elements);

  for (uint64_t i = 0; i < num_elements; i++) {
    uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
    uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
    auto mr_int_type = static_cast<IntegerType>(i_type);
    int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
    i_source += (kUnsignedOne << i_type);
    array_data[i] = static_cast<T>(i64);
  }

  auto data = reinterpret_cast<const unsigned char *>(array_data.get());
  *data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
  memcpy(data_ptr->get(), data, *num_bytes);

  return SUCCESS;
}
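
The inverse direction, again only an illustration of the assumed format rather than the shipped implementation; it round-trips with the compress_int sketch above.

def uncompress_int(buf: bytes):
    """Sketch of the assumed decoder for the layout written by compress_int above."""
    count = int.from_bytes(buf[0:4], 'big')
    pos = 4 + (count + 3) // 4                                # skip element count and width bitmap
    values = []
    for i in range(count):
        code = (buf[4 + i // 4] >> (2 * (3 - i % 4))) & 0x3   # kDataTypeBitMask = 0x3 (assumed)
        n_bytes = 1 << code
        values.append(int.from_bytes(buf[pos:pos + n_bytes], 'little', signed=True))
        pos += n_bytes
    return values

# Round trip over the kind of boundary values the new NLP test writes (127/128, 32767/32768, ...):
# uncompress_int(compress_int([0, 1, -1, 127, -128, 32767, -2147483648])) returns the same list.
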
uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
                                       const IntegerType &i_type) {
  uint64_t result = 0;
  for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) {
    result = (result << kBitsOfByte) + bytes_array[pos + i];
  }
  return result;
}

std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
  uint64_t n_bytes = kUnsignedOne << i_type;
  std::vector<uint8_t> result(n_bytes, 0);
  for (uint64_t i = 0; i < n_bytes; i++) {
    result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
    value >>= kBitsOfByte;
  }
  return result;
}

std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
  uint64_t n_bytes = kUnsignedOne << i_type;
  std::vector<uint8_t> result(n_bytes, 0);
  for (uint64_t i = 0; i < n_bytes; i++) {
    result[i] = value & std::numeric_limits<uint8_t>::max();
    value >>= kBitsOfByte;
  }
  return result;
}

int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
                                             const IntegerType &src_i_type, IntegerType *dst_i_type) {
  uint64_t u_temp = 0;
  for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) {
    u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i];
  }

  int64_t i_out;
  switch (src_i_type) {
    case kInt8Type: {
      i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
      break;
    }
    case kInt16Type: {
      i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
      break;
    }
    case kInt32Type: {
      i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
      break;
    }
    case kInt64Type: {
      i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
      break;
    }
    default: {
      i_out = 0;
    }
  }

  if (!dst_i_type) {
    return i_out;
  }

  if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
      i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
    *dst_i_type = kInt8Type;
  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
             i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
    *dst_i_type = kInt16Type;
  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
             i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
    *dst_i_type = kInt32Type;
  } else {
    *dst_i_type = kInt64Type;
  }
  return i_out;
}
} // namespace mindrecord
} // namespace mindspore

@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
      json header;
      header = ret.second;
      header["shard_addresses"] = realAddresses;
      if (header["version"] != version_) {
      if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
        MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
                      << ", lib version is: " << version_;
                      << ", lib version is: " << kVersion;
        thread_status = true;
        return;
      }
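
The version gate loosens here from an exact match against the instance's version_ string to membership in a supported-version list, so files written by an older (uncompressed) writer still open. Roughly, in Python, with the list contents assumed since kSupportedVersion is not shown in this hunk:

SUPPORTED_VERSIONS = ["2.0", "3.0"]   # stand-in for kSupportedVersion; the real contents are defined elsewhere

def version_ok(file_version: str) -> bool:
    return file_version in SUPPORTED_VERSIONS
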

@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
    s += "\"shard_addresses\":" + address + ",";
    s += "\"shard_id\":" + std::to_string(shardId) + ",";
    s += "\"statistics\":" + stats + ",";
    s += "\"version\":\"" + version_ + "\"";
    s += "\"version\":\"" + std::string(kVersion) + "\"";
    s += "}";
    header.emplace_back(s);
  }

@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema):
    if not blob_fields:
        return raw

    # Get the order preserving sequence of columns in blob
    ordered_columns = []
    loaded_columns = []
    if columns:
        for blob_field in blob_fields:
            if blob_field in columns:
                ordered_columns.append(blob_field)
        for column in columns:
            if column in blob_fields:
                loaded_columns.append(column)
    else:
        ordered_columns = blob_fields

    blob_bytes = bytes(blob)
        loaded_columns = blob_fields

    def _render_raw(field, blob_data):
        data_type = schema[field]['type']

@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
        else:
            raw[field] = blob_data

    if len(blob_fields) == 1:
    if len(ordered_columns) == 1:
        _render_raw(blob_fields[0], blob_bytes)
        return raw
    return raw

    def _int_from_bytes(xbytes: bytes) -> int:
        return int.from_bytes(xbytes, 'big')

    def _blob_at_position(pos):
        start = 0
        for _ in range(pos):
            n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
            start += 8 + n_bytes
        n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
        start += 8
        return blob_bytes[start : start + n_bytes]

    for i, blob_field in enumerate(ordered_columns):
        _render_raw(blob_field, _blob_at_position(i))
    for i, blob_field in enumerate(loaded_columns):
        _render_raw(blob_field, bytes(blob[i]))
    return raw
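
With the pure-Python splitter (_blob_at_position) removed, populate_data now expects the reader to hand it one already-uncompressed buffer per loaded blob column, in the order of loaded_columns. A rough sketch of the new call shape, with made-up field names and values:

# Hypothetical shapes only; the real buffers come from the C++ ShardColumn layer.
schema = {"array_a": {"type": "int32", "shape": [-1]}, "array_c": {"type": "bytes"}}
blob = [b"\x00\x01\x02\x03", b"nlp data"]      # one buffer per loaded blob column, already uncompressed
raw = populate_data({}, blob, ["array_a", "array_c"], ["array_a", "array_c"], schema)
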

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

@ -35,6 +35,7 @@ CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"


@ -46,7 +47,8 @@ def add_and_remove_cv_file():
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
        os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
        os.remove("{}.db".format(x)) if os.path.exists(
            "{}.db".format(x)) else None
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},

@ -96,13 +98,105 @@ def add_and_remove_nlp_file():
        os.remove("{}.db".format(x))


@pytest.fixture
def add_and_remove_nlp_compress_file():
    """add/remove nlp file"""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    nlp_schema_json = {"label": {"type": "int32"},
                       "array_a": {"type": "int32",
                                   "shape": [-1]},
                       "array_b": {"type": "int64",
                                   "shape": [1, -1]},
                       "array_c": {"type": "bytes"},
                       "array_d": {"type": "int64",
                                   "shape": [2, -1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_nlp_compress_data(add_and_remove_nlp_compress_file):
    """tutorial for nlp mindrecord dataset with compressed int columns."""
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    num_readers = 1
    data_set = ds.MindDataset(
        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    assert data_set.get_dataset_size() == 16
    num_iter = 0
    for x, item in zip(data, data_set.create_dict_iterator()):
        assert (item["array_a"] == x["array_a"]).all()
        assert (item["array_b"] == x["array_b"]).all()
        assert item["array_c"].tobytes() == x["array_c"]
        assert (item["array_d"] == x["array_d"]).all()
        assert item["label"] == x["label"]
        num_iter += 1
    assert num_iter == 16


def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file):
    """tutorial for nlp mindrecord dataset: compressed file matches an old uncompressed file."""
    num_readers = 1
    data_set = ds.MindDataset(
        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    old_data_set = ds.MindDataset(
        OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    assert old_data_set.get_dataset_size() == 16
    num_iter = 0
    for x, item in zip(old_data_set.create_dict_iterator(), data_set.create_dict_iterator()):
        assert (item["array_a"] == x["array_a"]).all()
        assert (item["array_b"] == x["array_b"]).all()
        assert (item["array_c"] == x["array_c"]).all()
        assert (item["array_d"] == x["array_d"]).all()
        assert item["label"] == x["label"]
        num_iter += 1
    assert num_iter == 16

def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
        os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
        os.remove("{}.db".format(x)) if os.path.exists(
            "{}.db".format(x)) else None
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},

@ -127,8 +221,10 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
|||
num_shards=num_shards, shard_id=partition_id)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
return num_iter
|
||||
|
||||
|
@ -147,9 +243,12 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
|
|||
data_set = data_set.repeat(repeat_num)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
|
@ -163,17 +262,22 @@ def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
|
|||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
decode_op = vision.Decode()
|
||||
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
data_set = data_set.map(
|
||||
input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
||||
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.map(input_columns="data",
|
||||
operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.batch(2)
|
||||
data_set = data_set.repeat(2)
|
||||
num_iter = 0
|
||||
labels = []
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
labels.append(item["label"])
|
||||
assert num_iter == 10
|
||||
|
@ -189,15 +293,20 @@ def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
|
|||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
decode_op = vision.Decode()
|
||||
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
data_set = data_set.map(
|
||||
input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
||||
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.map(input_columns="data",
|
||||
operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.batch(32, drop_remainder=True)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 0
|
||||
|
||||
|
@ -206,7 +315,8 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file):
|
|||
"""issue 888 test."""
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 2
|
||||
data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
|
||||
data = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False, num_shards=5, shard_id=1)
|
||||
data = data.shuffle(2)
|
||||
data = data.repeat(9)
|
||||
num_iter = 0
|
||||
|
@ -226,9 +336,12 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
|
|||
data_set = data_set.repeat(repeat_num)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- block reader repeat tow {} -----------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
|
||||
|
@ -244,10 +357,14 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
|
|||
data_set = data_set.repeat(repeat_num)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- block reader repeat tow {} -----------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[id]: {} ----------------------------".format(item["id"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
|
||||
|
@ -256,15 +373,21 @@ def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
|
|||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers)
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + str(x)
|
||||
for x in range(FILES_NUM)], columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
@ -277,11 +400,16 @@ def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() < 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter < 10
|
||||
|
||||
|
@ -324,11 +452,16 @@ def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 30
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 30
|
||||
if os.path.exists(CV1_FILE_NAME):
|
||||
|
@ -346,7 +479,8 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
|
|||
for x in range(FILES_NUM)]
|
||||
for x in paths:
|
||||
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
|
||||
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
|
||||
os.remove("{}.db".format(x)) if os.path.exists(
|
||||
"{}.db".format(x)) else None
|
||||
writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
|
||||
data = get_data(CV_DIR_NAME)
|
||||
cv_schema_json = {"id": {"type": "int32"},
|
||||
|
@ -365,11 +499,16 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() < 20
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter < 20
|
||||
for x in paths:
|
||||
|
@ -385,11 +524,16 @@ def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
@ -401,10 +545,14 @@ def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
|
|||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- num_iter: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
|
||||
logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- num_iter: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[id]: {} ------------------------".format(item["id"]))
|
||||
logger.info(
|
||||
"-------------- item[rating]: {} --------------------".format(item["rating"]))
|
||||
logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
|
||||
item["input_ids"], item["input_ids"].shape))
|
||||
logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
|
||||
|
@ -445,10 +593,13 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
|
|||
|
||||
# define map operations
|
||||
decode_op = vision.Decode()
|
||||
resize_op = vision.Resize((resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)
|
||||
resize_op = vision.Resize(
|
||||
(resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)
|
||||
|
||||
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
|
||||
data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
|
||||
data_set = data_set.map(
|
||||
input_columns=["data"], operations=decode_op, num_parallel_workers=4)
|
||||
data_set = data_set.map(
|
||||
input_columns=["data"], operations=resize_op, num_parallel_workers=4)
|
||||
|
||||
data_set = data_set.batch(2)
|
||||
assert data_set.get_dataset_size() == 5
|
||||
|
@ -468,11 +619,16 @@ def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
@ -486,11 +642,16 @@ def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
|
|||
data_set = data_set.repeat(repeat_num)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- repeat two test {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} ----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ---------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- repeat two test {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} ----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ---------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
|
||||
|
@ -599,7 +760,8 @@ def get_mkv_data(dir_name):
|
|||
"id": index}
|
||||
data_list.append(data_json)
|
||||
index += 1
|
||||
logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
|
||||
logger.info('{} images are missing'.format(
|
||||
len(file_list) - len(data_list)))
|
||||
return data_list
|
||||
|
||||
|
||||
|
@ -686,6 +848,10 @@ def inputs(vectors, maxlen=50):
|
|||
|
||||
def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
if os.path.exists("{}".format(mindrecord_file_name)):
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
if os.path.exists("{}.db".format(mindrecord_file_name)):
|
||||
os.remove("{}.db".format(x))
|
||||
data = [{"file_name": "001.jpg", "label": 4,
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
|
@ -782,7 +948,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['file_name'] = np.asarray(
|
||||
list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
|
@ -807,7 +974,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 13
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -815,7 +983,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"],
|
||||
columns_list=["source_sos_ids",
|
||||
"source_sos_mask", "target_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -832,7 +1001,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
|
||||
num_readers = 1
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
|
||||
columns_list=[
|
||||
"image2", "source_sos_mask", "image3", "target_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -841,7 +1011,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 4
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -849,7 +1020,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
|
||||
num_readers = 3
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_sos_ids", "image4", "source_sos_ids"],
|
||||
columns_list=["target_sos_ids",
|
||||
"image4", "source_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -858,7 +1030,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -866,7 +1039,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
|
||||
num_readers = 3
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"],
|
||||
columns_list=["target_sos_ids", "image5",
|
||||
"image4", "image3", "source_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -875,7 +1049,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -883,7 +1058,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
|
||||
num_readers = 1
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"],
|
||||
columns_list=["target_eos_mask", "image5",
|
||||
"image2", "source_sos_mask", "label"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -892,7 +1068,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -910,7 +1087,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
|||
assert len(item) == 11
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -975,7 +1153,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['file_name'] = np.asarray(
|
||||
list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
|
@ -994,7 +1173,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 7
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1011,7 +1191,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1028,7 +1209,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 2
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1045,7 +1227,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 2
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1062,7 +1245,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1070,7 +1254,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image4", "image5", "image2", "image3", "file_name"],
|
||||
columns_list=["image4", "image5",
|
||||
"image2", "image3", "file_name"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
|
@ -1079,7 +1264,8 @@ def test_write_with_multi_bytes_and_MindDataset():
|
|||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1177,7 +1363,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 8
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1196,7 +1383,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 6
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1215,7 +1403,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1234,7 +1423,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1251,7 +1441,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 1
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
@ -1271,7 +1462,8 @@ def test_write_with_multi_array_and_MindDataset():
|
|||
assert len(item) == 8
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
|
|
|
@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS
|
|||
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
|
||||
MINDRECORD_FILE = "./cifar100.mindrecord"
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
def test_cifar100_to_mindrecord_without_index_fields():
|
||||
remove_file(MINDRECORD_FILE)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(MINDRECORD_FILE)
|
||||
|
||||
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
|
||||
"""test transform cifar100 dataset to mindrecord without index fields."""
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
|
||||
ret = cifar100_transformer.transform()
|
||||
|
@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields():
|
|||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
os.remove("{}".format(MINDRECORD_FILE))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE))
|
||||
|
||||
os.remove("{}".format(MINDRECORD_FILE + "_test"))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord():
|
||||
def test_cifar100_to_mindrecord(fixture_file):
|
||||
"""test transform cifar100 dataset to mindrecord."""
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
|
||||
cifar100_transformer.transform(['fine_label', 'coarse_label'])
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
os.remove("{}".format(MINDRECORD_FILE))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE))
|
||||
|
||||
os.remove("{}".format(MINDRECORD_FILE + "_test"))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
|
||||
|
||||
|
||||
def read():
|
||||
|
@ -77,8 +82,7 @@ def read():
|
|||
assert count == 4
|
||||
reader.close()
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord_illegal_file_name():
|
||||
def test_cifar100_to_mindrecord_illegal_file_name(fixture_file):
|
||||
"""
|
||||
test transform cifar100 dataset to mindrecord
|
||||
when file name contains illegal character.
|
||||
|
@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name():
|
|||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
|
||||
cifar100_transformer.transform()
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord_filename_start_with_space():
|
||||
def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
|
||||
"""
|
||||
test transform cifar100 dataset to mindrecord
|
||||
when file name starts with space.
|
||||
|
@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space():
|
|||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
|
||||
cifar100_transformer.transform()
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord_filename_contain_space():
|
||||
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
|
||||
"""
|
||||
test transform cifar100 dataset to mindrecord
|
||||
when file name contains space.
|
||||
|
@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space():
|
|||
cifar100_transformer.transform()
|
||||
assert os.path.exists(filename)
|
||||
assert os.path.exists(filename + "_test")
|
||||
os.remove("{}".format(filename))
|
||||
os.remove("{}.db".format(filename))
|
||||
|
||||
os.remove("{}".format(filename + "_test"))
|
||||
os.remove("{}.db".format(filename + "_test"))
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord_directory():
|
||||
def test_cifar100_to_mindrecord_directory(fixture_file):
|
||||
"""
|
||||
test transform cifar100 dataset to mindrecord
|
||||
when destination path is directory.
|
||||
|
@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory():
|
|||
CIFAR100_DIR)
|
||||
cifar100_transformer.transform()
|
||||
|
||||
|
||||
def test_cifar100_to_mindrecord_filename_equals_cifar100():
|
||||
def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
|
||||
"""
|
||||
test transform cifar100 dataset to mindrecord
|
||||
when destination path equals source path.
|
||||
|
|
|
@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS
|
|||
CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
|
||||
MINDRECORD_FILE = "./cifar10.mindrecord"
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
def test_cifar10_to_mindrecord_without_index_fields():
|
||||
remove_file(MINDRECORD_FILE)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(MINDRECORD_FILE)
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_space_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
x = "./yes ok"
|
||||
remove_file(x)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(x)
|
||||
|
||||
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord without index fields."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
cifar10_transformer.transform()
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
os.remove("{}".format(MINDRECORD_FILE))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE))
|
||||
|
||||
os.remove("{}".format(MINDRECORD_FILE + "_test"))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord():
|
||||
|
||||
def test_cifar10_to_mindrecord(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
cifar10_transformer.transform(['label'])
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
os.remove("{}".format(MINDRECORD_FILE))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE))
|
||||
|
||||
os.remove("{}".format(MINDRECORD_FILE + "_test"))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord_with_return():
|
||||
def test_cifar10_to_mindrecord_with_return(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
ret = cifar10_transformer.transform(['label'])
|
||||
|
@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return():
|
|||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
os.remove("{}".format(MINDRECORD_FILE))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE))
|
||||
|
||||
os.remove("{}".format(MINDRECORD_FILE + "_test"))
|
||||
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
|
||||
|
||||
|
||||
def read():
|
||||
|
@ -90,8 +109,7 @@ def read():
|
|||
assert count == 4
|
||||
reader.close()
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord_illegal_file_name():
|
||||
def test_cifar10_to_mindrecord_illegal_file_name(fixture_file):
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when file name contains illegal character.
|
||||
|
@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name():
|
|||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
|
||||
cifar10_transformer.transform()
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord_filename_start_with_space():
|
||||
def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when file name starts with space.
|
||||
|
@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space():
|
|||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
|
||||
cifar10_transformer.transform()
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord_filename_contain_space():
|
||||
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when file name contains space.
|
||||
|
@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space():
|
|||
cifar10_transformer.transform()
|
||||
assert os.path.exists(filename)
|
||||
assert os.path.exists(filename + "_test")
|
||||
os.remove("{}".format(filename))
|
||||
os.remove("{}.db".format(filename))
|
||||
|
||||
os.remove("{}".format(filename + "_test"))
|
||||
os.remove("{}.db".format(filename + "_test"))
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord_directory():
|
||||
def test_cifar10_to_mindrecord_directory(fixture_file):
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when destination path is directory.
|
||||
|
|
|
@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
PARTITION_NUMBER = 4

@pytest.fixture
def fixture_file():
    """add/remove file"""
    def remove_one_file(x):
        if os.path.exists(x):
            os.remove(x)
    def remove_file():
        x = MINDRECORD_FILE
        remove_one_file(x)
        x = MINDRECORD_FILE + ".db"
        remove_one_file(x)
        for i in range(PARTITION_NUMBER):
            x = MINDRECORD_FILE + str(i)
            remove_one_file(x)
            x = MINDRECORD_FILE + str(i) + ".db"
            remove_one_file(x)

    remove_file()
    yield "yield_fixture_data"
    remove_file()
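The fixture above is the pattern this change applies across the mindrecord test modules: clean up before the test, yield to the test body, clean up again afterwards, so output files are removed even when an assertion fails. A minimal standalone sketch of the same setup/teardown mechanism (illustrative only; clean_output and example.mindrecord are hypothetical names, not part of this change):

import os
import pytest

@pytest.fixture
def clean_output():
    path = "example.mindrecord"          # hypothetical output file
    if os.path.exists(path):
        os.remove(path)                  # setup: start from a clean slate
    yield path                           # the test body runs here
    if os.path.exists(path):
        os.remove(path)                  # teardown: runs even if the test failed

def test_writer_creates_file(clean_output):
    with open(clean_output, "w") as f:   # stand-in for a real *ToMR transform
        f.write("data")
    assert os.path.exists(clean_output)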
def read(filename):
    """test file reade"""

@ -38,8 +58,7 @@ def read(filename):
    assert count == 20
    reader.close()


def test_imagenet_to_mindrecord():
def test_imagenet_to_mindrecord(fixture_file):
    """test transform imagenet dataset to mindrecord."""
    imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
                                        MINDRECORD_FILE, PARTITION_NUMBER)

@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord():
        assert os.path.exists(MINDRECORD_FILE + str(i))
        assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
    read(MINDRECORD_FILE + "0")
    for i in range(PARTITION_NUMBER):
        os.remove(MINDRECORD_FILE + str(i))
        os.remove(MINDRECORD_FILE + str(i) + ".db")


def test_imagenet_to_mindrecord_default_partition_number():
def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
    """
    test transform imagenet dataset to mindrecord
    when partition number is default.

@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number():
    assert os.path.exists(MINDRECORD_FILE)
    assert os.path.exists(MINDRECORD_FILE + ".db")
    read(MINDRECORD_FILE)
    os.remove("{}".format(MINDRECORD_FILE))
    os.remove("{}.db".format(MINDRECORD_FILE))


def test_imagenet_to_mindrecord_partition_number_0():
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
    """
    test transform imagenet dataset to mindrecord
    when partition number is 0.

@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0():
                                            MINDRECORD_FILE, 0)
        imagenet_transformer.transform()


def test_imagenet_to_mindrecord_partition_number_none():
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
    """
    test transform imagenet dataset to mindrecord
    when partition number is none.

@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none():
                                            MINDRECORD_FILE, None)
        imagenet_transformer.transform()


def test_imagenet_to_mindrecord_illegal_filename():
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
    """
    test transform imagenet dataset to mindrecord
    when file name contains illegal character.
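As the assertions above show, a partitioned conversion produces one shard file per partition plus a .db index for each (imagenet.mindrecord0 through imagenet.mindrecord3 when PARTITION_NUMBER is 4), while the default produces a single imagenet.mindrecord and its .db. A small sketch of computing the expected paths, mirroring what fixture_file deletes (the helper name is hypothetical, not part of this change):

def expected_outputs(base_name, partitions=1):
    """Hypothetical helper: list every file a conversion into `partitions` shards should create."""
    if partitions == 1:
        names = [base_name]
    else:
        names = [base_name + str(i) for i in range(partitions)]
    # every data file is paired with a .db index file
    return [p for name in names for p in (name, name + ".db")]

# e.g. expected_outputs(MINDRECORD_FILE, PARTITION_NUMBER) lists 8 paths: 4 shards + 4 indexes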
@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord"
FILES_NUM = 4

def remove_one_file(x):
    if os.path.exists(x):
        os.remove(x)

def remove_file(file_name):
    x = file_name
    remove_one_file(x)
    x = file_name + ".db"
    remove_one_file(x)
    for i in range(FILES_NUM):
        x = file_name + str(i)
        remove_one_file(x)
        x = file_name + str(i) + ".db"
        remove_one_file(x)

@pytest.fixture
def fixture_cv_file():
    """add/remove file"""
    remove_file(CV_FILE_NAME)
    yield "yield_fixture_data"
    remove_file(CV_FILE_NAME)

@pytest.fixture
def fixture_nlp_file():
    """add/remove file"""
    remove_file(NLP_FILE_NAME)
    yield "yield_fixture_data"
    remove_file(NLP_FILE_NAME)

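The error-path tests below all follow the same shape: create (or deliberately damage) a MindRecord file, then assert that FileReader or MindPage fails with the expected MRMOpenError message, while fixture_cv_file guarantees the leftovers are removed. A condensed sketch of that pattern (illustrative only; the test name is hypothetical, and it assumes MRMOpenError is importable from mindspore.mindrecord, as these tests already use it):

import pytest
from mindspore.mindrecord import FileReader, MRMOpenError

def test_open_missing_file_raises(fixture_cv_file):
    """Opening a MindRecord file that was never written should fail loudly."""
    with pytest.raises(MRMOpenError) as err:
        FileReader(CV_FILE_NAME)  # nothing was written, so the open must fail
    assert 'MindRecord File could not open successfully' in str(err.value)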
def test_cv_file_writer_shard_num_none():
    """test cv file writer when shard num is None."""

@ -83,8 +111,7 @@ def test_lack_partition_and_db():
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)


def test_lack_db():
def test_lack_db(fixture_cv_file):
    """test file reader when db file does not exist."""
    create_cv_mindrecord(1)
    os.remove("{}.db".format(CV_FILE_NAME))

@ -94,10 +121,8 @@ def test_lack_db():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    os.remove(CV_FILE_NAME)


def test_lack_some_partition_and_db():
def test_lack_some_partition_and_db(fixture_cv_file):
    """test file reader when some partition and db do not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))

@ -110,16 +135,8 @@ def test_lack_some_partition_and_db():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_lack_some_partition_first():
def test_lack_some_partition_first(fixture_cv_file):
    """test file reader when first partition does not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))

@ -131,14 +148,8 @@ def test_lack_some_partition_first():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_lack_some_partition_middle():
def test_lack_some_partition_middle(fixture_cv_file):
    """test file reader when some partition does not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))

@ -150,14 +161,8 @@ def test_lack_some_partition_middle():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_lack_some_partition_last():
def test_lack_some_partition_last(fixture_cv_file):
    """test file reader when last partition does not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
@ -169,14 +174,8 @@ def test_lack_some_partition_last():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_mindpage_lack_some_partition():
def test_mindpage_lack_some_partition(fixture_cv_file):
    """test page reader when some partition does not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))

@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_lack_some_db():
def test_lack_some_db(fixture_cv_file):
    """test file reader when some db does not exist."""
    create_cv_mindrecord(4)
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))

@ -206,11 +199,6 @@ def test_lack_some_db():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))


def test_invalid_mindrecord():

@ -225,8 +213,7 @@ def test_invalid_mindrecord():
           in str(err.value)
    os.remove(CV_FILE_NAME)


def test_invalid_db():
def test_invalid_db(fixture_cv_file):
    """test file reader when the content of db is illegal."""
    create_cv_mindrecord(1)
    os.remove("imagenet.mindrecord.db")

@ -237,11 +224,8 @@ def test_invalid_db():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    os.remove("imagenet.mindrecord")
    os.remove("imagenet.mindrecord.db")


def test_overwrite_invalid_mindrecord():
def test_overwrite_invalid_mindrecord(fixture_cv_file):
    """test file writer when overwrite invalid mindreocrd file."""
    with open(CV_FILE_NAME, 'w') as f:
        f.write('just for test')

@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord():
    assert '[MRMOpenError]: error_code: 1347690596, ' \
           'error_msg: MindRecord File could not open successfully.' \
           in str(err.value)
    os.remove(CV_FILE_NAME)


def test_overwrite_invalid_db():
def test_overwrite_invalid_db(fixture_cv_file):
    """test file writer when overwrite invalid db file."""
    with open('imagenet.mindrecord.db', 'w') as f:
        f.write('just for test')

@ -261,11 +243,8 @@ def test_overwrite_invalid_db():
        create_cv_mindrecord(1)
    assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \
           'error_msg: Failed to generate index.' in str(err.value)
    os.remove("imagenet.mindrecord")
    os.remove("imagenet.mindrecord.db")


def test_read_after_close():
def test_read_after_close(fixture_cv_file):
    """test file reader when close read."""
    create_cv_mindrecord(1)
    reader = FileReader(CV_FILE_NAME)
@ -275,11 +254,8 @@ def test_read_after_close():
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 0
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_file_read_after_read():
def test_file_read_after_read(fixture_cv_file):
    """test file reader when finish read."""
    create_cv_mindrecord(1)
    reader = FileReader(CV_FILE_NAME)

@ -295,8 +271,6 @@ def test_file_read_after_read():
        cnt = cnt + 1
        logger.info("#item{}: {}".format(index, x))
    assert cnt == 0
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_cv_file_writer_shard_num_greater_than_1000():

@ -312,8 +286,7 @@ def test_add_index_without_add_schema():
        fw.add_index(["label"])
    assert 'Failed to get meta info' in str(err.value)


def test_mindpage_pageno_pagesize_not_int():
def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
    """test page reader when some partition does not exist."""
    create_cv_mindrecord(4)
    reader = MindPage(CV_FILE_NAME + "0")

@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int():
    with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
        reader.read_at_page_by_id(99999, 0, 1)

    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_mindpage_filename_not_exist():
def test_mindpage_filename_not_exist(fixture_cv_file):
    """test page reader when some partition does not exist."""
    create_cv_mindrecord(4)
    reader = MindPage(CV_FILE_NAME + "0")

@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist():

    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))

@ -14,6 +14,7 @@
"""test mnist to mindrecord tool"""
import cv2
import gzip
import pytest
import numpy as np
import os

@ -27,6 +28,34 @@ PARTITION_NUM = 4
IMAGE_SIZE = 28
NUM_CHANNELS = 1

@pytest.fixture
def fixture_file():
    """add/remove file"""
    def remove_one_file(x):
        if os.path.exists(x):
            os.remove(x)
    def remove_file():
        x = "mnist_train.mindrecord"
        remove_one_file(x)
        x = "mnist_train.mindrecord.db"
        remove_one_file(x)
        x = "mnist_test.mindrecord"
        remove_one_file(x)
        x = "mnist_test.mindrecord.db"
        remove_one_file(x)
        for i in range(PARTITION_NUM):
            x = "mnist_train.mindrecord" + str(i)
            remove_one_file(x)
            x = "mnist_train.mindrecord" + str(i) + ".db"
            remove_one_file(x)
            x = "mnist_test.mindrecord" + str(i)
            remove_one_file(x)
            x = "mnist_test.mindrecord" + str(i) + ".db"
            remove_one_file(x)

    remove_file()
    yield "yield_fixture_data"
    remove_file()

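With cleanup folded into fixture_file, this module can be run repeatedly without leaving mnist_*.mindrecord artifacts behind, even when a test fails midway. One way to drive just these tests programmatically (illustrative; the -k expression is an assumption about which tests to select):

import pytest

if __name__ == "__main__":
    # pytest handles fixture setup/teardown around each selected test
    raise SystemExit(pytest.main(["-v", "-k", "mnist_to_mindrecord"]))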
def read(train_name, test_name):
    """test file reader"""

@ -51,7 +80,7 @@ def read(train_name, test_name):
    reader.close()


def test_mnist_to_mindrecord():
def test_mnist_to_mindrecord(fixture_file):
    """test transform mnist dataset to mindrecord."""
    mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
    mnist_transformer.transform()

@ -60,13 +89,7 @@ def test_mnist_to_mindrecord():
    read("mnist_train.mindrecord", "mnist_test.mindrecord")

    os.remove("{}".format("mnist_train.mindrecord"))
    os.remove("{}.db".format("mnist_train.mindrecord"))
    os.remove("{}".format("mnist_test.mindrecord"))
    os.remove("{}.db".format("mnist_test.mindrecord"))


def test_mnist_to_mindrecord_compare_data():
def test_mnist_to_mindrecord_compare_data(fixture_file):
    """test transform mnist dataset to mindrecord and compare data."""
    mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
    mnist_transformer.transform()

@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data():
        assert np.array(x['label']) == label
    reader.close()

    os.remove("{}".format("mnist_train.mindrecord"))
    os.remove("{}.db".format("mnist_train.mindrecord"))
    os.remove("{}".format("mnist_test.mindrecord"))
    os.remove("{}.db".format("mnist_test.mindrecord"))


def test_mnist_to_mindrecord_multi_partition():
def test_mnist_to_mindrecord_multi_partition(fixture_file):
    """test transform mnist dataset to multiple mindrecord files."""
    mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
    mnist_transformer.transform()

    read("mnist_train.mindrecord0", "mnist_test.mindrecord0")

    for i in range(PARTITION_NUM):
        os.remove("{}".format("mnist_train.mindrecord" + str(i)))
        os.remove("{}.db".format("mnist_train.mindrecord" + str(i)))
        os.remove("{}".format("mnist_test.mindrecord" + str(i)))
        os.remove("{}.db".format("mnist_test.mindrecord" + str(i)))