save op in minddataset
This commit is contained in:
parent
220f090ce5
commit
bc676fe250
|
@ -42,11 +42,17 @@
|
|||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/mindrecord/include/shard_category.h"
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_header.h"
|
||||
#include "minddata/mindrecord/include/shard_index_generator.h"
|
||||
#include "minddata/mindrecord/include/shard_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
#include "minddata/mindrecord/include/shard_writer.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using json = nlohmann::json;
|
||||
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *, std::shared_ptr<DatasetOp> *);
|
||||
|
||||
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
|
||||
|
@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
Status s;
|
||||
auto mr_header = std::make_shared<mindrecord::ShardHeader>();
|
||||
auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
|
||||
std::vector<std::string> blob_fields;
|
||||
uint64_t mr_schema_id = 0;
|
||||
if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter.");
|
||||
}
|
||||
|
||||
TensorRow row;
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map =
|
||||
iterator_->GetColumnNameMap(); // map of column name, id
|
||||
bool first_loop = true; // build schema in first loop
|
||||
do {
|
||||
json row_raw_data;
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
s = iterator_->FetchNextTensorRow(&row);
|
||||
}
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (row.empty()) break;
|
||||
if (first_loop) {
|
||||
json mr_json;
|
||||
std::vector<std::string> index_fields;
|
||||
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
|
||||
mr_writer->SetShardHeader(mr_header);
|
||||
first_loop = false;
|
||||
}
|
||||
// construct data
|
||||
if (!row.empty()) { // write data
|
||||
s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
std::shared_ptr<std::vector<uint8_t>> output_bin_data;
|
||||
mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data);
|
||||
std::map<std::uint64_t, std::vector<json>> raw_data;
|
||||
raw_data.insert(std::pair<uint64_t, std::vector<json>>(mr_schema_id, std::vector<json>{row_raw_data}));
|
||||
std::vector<std::vector<uint8_t>> bin_data;
|
||||
if (nullptr != output_bin_data) {
|
||||
bin_data.emplace_back(*output_bin_data);
|
||||
}
|
||||
mr_writer->WriteRawData(raw_data, bin_data);
|
||||
}
|
||||
} while (!row.empty());
|
||||
mr_writer->Commit();
|
||||
mindrecord::ShardIndexGenerator::finalize(file_names);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
json *row_raw_data,
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
|
||||
if (row_raw_data == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("error: row raw data is NULL.");
|
||||
}
|
||||
if (row_bin_data == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("error: row bin data is NULL.");
|
||||
}
|
||||
if (column_name_id_map.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: column not found");
|
||||
}
|
||||
Status s;
|
||||
for (auto &col : column_name_id_map) {
|
||||
auto idx = col.second;
|
||||
auto column_name = col.first;
|
||||
auto &tensor = row[idx];
|
||||
auto column_type = tensor->type();
|
||||
|
||||
std::unique_ptr<std::vector<uint8_t>> data_ptr;
|
||||
if (column_type == DataType::DE_INT8) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<int8_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT16) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<int16_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT16) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<uint16_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT8) {
|
||||
std::unique_ptr<uint8_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT32) {
|
||||
std::unique_ptr<int32_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT32) {
|
||||
std::unique_ptr<int64_t> data;
|
||||
std::unique_ptr<uint32_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT64) {
|
||||
std::unique_ptr<int64_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_FLOAT32) {
|
||||
std::unique_ptr<float> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_FLOAT64) {
|
||||
std::unique_ptr<double> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_STRING) {
|
||||
auto buffer = tensor->GetStringsBuffer();
|
||||
std::string ss(reinterpret_cast<const char *>(buffer)); // assume scalar string tensor
|
||||
(*row_raw_data)[column_name] = std::move(ss);
|
||||
continue;
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data.");
|
||||
}
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data_ptr != nullptr) {
|
||||
(*row_bin_data)[column_name] = std::move(data_ptr);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
|
||||
std::unique_ptr<S> *s, bool need_convert) {
|
||||
if (nullptr == src) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL.");
|
||||
}
|
||||
*data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
|
||||
if (need_convert) {
|
||||
auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
|
||||
std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
|
||||
auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
|
||||
auto el = std::make_unique<T>();
|
||||
for (uint32_t i = 0; i < num_of_elements; ++i) {
|
||||
*el = *(s_ptr + i);
|
||||
auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
|
||||
for (uint32_t j = 0; j < sizeof(T); ++j) {
|
||||
*((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
|
||||
}
|
||||
if (shape.empty()) {
|
||||
*data = std::make_unique<T>();
|
||||
auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
|
||||
for (uint32_t i = 0; i < sizeof(T); ++i) {
|
||||
*(t_ptr + i) = *((*data_ptr)->begin() + i);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
const TensorRow &row, json *schema, std::vector<std::string> *index_fields) {
|
||||
if (schema == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("error: schema is NULL.");
|
||||
}
|
||||
if (index_fields == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("error: index fields is NULL.");
|
||||
}
|
||||
if (column_name_id_map.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: column not found.");
|
||||
}
|
||||
for (auto &col : column_name_id_map) {
|
||||
auto idx = col.second;
|
||||
auto column_name = col.first;
|
||||
auto &tensor = row[idx];
|
||||
auto column_type = tensor->type();
|
||||
auto column_shape = tensor->shape();
|
||||
|
||||
std::string mr_type;
|
||||
auto shapes = column_shape.AsVector();
|
||||
std::vector<int> mr_shape(shapes.begin(), shapes.end());
|
||||
std::string el = column_type.ToString();
|
||||
if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
|
||||
std::string err_msg("Error: can not support data type: " + el);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
mr_type = mindrecord::kTypesMap.at(el);
|
||||
}
|
||||
if (mr_shape.empty()) {
|
||||
if (mr_type == "bytes") { // map to int32 when bytes without shape.
|
||||
mr_type == "int32";
|
||||
}
|
||||
(*schema)[column_name] = {{"type", mr_type}};
|
||||
} else {
|
||||
if (mr_type == "string") { // mindrecord can not support string with shape.
|
||||
std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
if (mr_type == "bytes") { // ignore shape of bytes in minrecord
|
||||
(*schema)[column_name] = {{"type", mr_type}};
|
||||
} else {
|
||||
(*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
|
||||
}
|
||||
}
|
||||
if (mr_type == "bytes" || !mr_shape.empty()) continue;
|
||||
index_fields->emplace_back(column_name); // candidate of index fields
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define DATASET_API_DE_PIPELINE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
|
@ -33,6 +34,7 @@
|
|||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using json = nlohmann::json;
|
||||
using DsOpPtr = std::shared_ptr<DatasetOp>;
|
||||
|
||||
class CacheClient;
|
||||
|
@ -100,6 +102,8 @@ class DEPipeline {
|
|||
|
||||
Status GetOutputTypes(py::list *output);
|
||||
|
||||
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
|
||||
|
||||
int GetDatasetSize() const;
|
||||
|
||||
int GetBatchSize() const;
|
||||
|
@ -110,6 +114,18 @@ class DEPipeline {
|
|||
|
||||
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
template <typename T, typename S>
|
||||
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
|
||||
std::unique_ptr<S> *s, bool need_convert = false);
|
||||
|
||||
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
|
||||
|
||||
Status FetchDataFromTensorRow(const TensorRow &row,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
|
||||
|
||||
Status BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded);
|
||||
|
|
|
@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) {
|
|||
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
|
||||
.def("GetBatchSize", &DEPipeline::GetBatchSize)
|
||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount);
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||
return true;
|
||||
});
|
||||
}
|
||||
void bindDatasetOps(py::module *m) {
|
||||
(void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
|
||||
|
|
|
@ -312,6 +312,11 @@ class Tensor {
|
|||
// @return const unsigned char*
|
||||
const unsigned char *GetBuffer() const;
|
||||
|
||||
// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
|
||||
// tensor's type is a string, otherwise undefined address would be returned.
|
||||
// @return address of the first string of the tensor.
|
||||
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
|
||||
|
||||
// Getter of the type
|
||||
// @return
|
||||
DataType type() const { return type_; }
|
||||
|
@ -643,11 +648,6 @@ class Tensor {
|
|||
// @return length of the string
|
||||
Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const;
|
||||
|
||||
// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
|
||||
// tensor's type is a string, otherwise undefined address would be returned.
|
||||
// @return address of the first string of the tensor.
|
||||
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
|
||||
|
||||
// all access to shape_ should be via shape
|
||||
TensorShape shape_;
|
||||
// data type of tensor
|
||||
|
|
|
@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\n Dataset file : ";
|
||||
out << "\nDataset file : ";
|
||||
for (auto &file : dataset_file_) {
|
||||
out << file << " ";
|
||||
}
|
||||
|
|
|
@ -137,6 +137,10 @@ const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "
|
|||
// number field list
|
||||
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
|
||||
|
||||
const std::unordered_map<std::string, std::string> kTypesMap = {
|
||||
{"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"},
|
||||
{"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"},
|
||||
{"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};
|
||||
/// \brief split a string using a character
|
||||
/// \param[in] field target string
|
||||
/// \param[in] separator a character for spliting
|
||||
|
|
|
@ -124,6 +124,10 @@ class ShardHeader {
|
|||
|
||||
MSRStatus FileToPages(const std::string dump_file_name);
|
||||
|
||||
static MSRStatus initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
|
||||
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
|
||||
uint64_t &schema_id);
|
||||
|
||||
private:
|
||||
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
|
||||
|
||||
|
|
|
@ -57,6 +57,8 @@ class ShardIndexGenerator {
|
|||
/// \brief create databases for indexes
|
||||
MSRStatus WriteToDatabase();
|
||||
|
||||
static MSRStatus finalize(const std::vector<std::string> file_names);
|
||||
|
||||
private:
|
||||
static int Callback(void *not_used, int argc, char **argv, char **az_col_name);
|
||||
|
||||
|
|
|
@ -108,6 +108,13 @@ class ShardWriter {
|
|||
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
|
||||
bool parallel_writer = false);
|
||||
|
||||
MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
|
||||
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
|
||||
std::shared_ptr<std::vector<uint8_t>> *output);
|
||||
|
||||
static MSRStatus initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
|
||||
const std::vector<std::string> &file_names);
|
||||
|
||||
private:
|
||||
/// \brief write shard header data to disk
|
||||
MSRStatus WriteShardHeader();
|
||||
|
|
|
@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
shard_no = task_++;
|
||||
}
|
||||
}
|
||||
MSRStatus ShardIndexGenerator::finalize(const std::vector<std::string> file_names) {
|
||||
if (file_names.empty()) {
|
||||
MS_LOG(ERROR) << "Mindrecord files is empty.";
|
||||
return FAILED;
|
||||
}
|
||||
ShardIndexGenerator sg{file_names[0]};
|
||||
if (SUCCESS != sg.Build()) {
|
||||
MS_LOG(ERROR) << "Failed to build index generator.";
|
||||
return FAILED;
|
||||
}
|
||||
if (SUCCESS != sg.WriteToDatabase()) {
|
||||
MS_LOG(ERROR) << "Failed to write to database.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
|
|||
*row_count = std::get<2>(v);
|
||||
return SUCCESS;
|
||||
}
|
||||
MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
|
||||
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
|
||||
std::shared_ptr<std::vector<uint8_t>> *output) {
|
||||
if (blob_fields.empty()) {
|
||||
return SUCCESS;
|
||||
}
|
||||
if (blob_fields.size() == 1) {
|
||||
auto &blob = row_bin_data.at(blob_fields[0]);
|
||||
auto blob_size = blob->size();
|
||||
*output = std::make_shared<std::vector<uint8_t>>(blob_size);
|
||||
std::copy(blob->begin(), blob->end(), (*output)->begin());
|
||||
} else {
|
||||
size_t output_size = 0;
|
||||
for (auto &field : blob_fields) {
|
||||
output_size += row_bin_data.at(field)->size();
|
||||
}
|
||||
output_size += blob_fields.size() * sizeof(uint64_t);
|
||||
*output = std::make_shared<std::vector<uint8_t>>(output_size);
|
||||
std::vector<uint8_t> buf(sizeof(uint64_t), 0);
|
||||
size_t idx = 0;
|
||||
for (auto &field : blob_fields) {
|
||||
auto &blob = row_bin_data.at(field);
|
||||
uint64_t blob_size = blob->size();
|
||||
// big edian
|
||||
for (size_t i = 0; i < buf.size(); ++i) {
|
||||
buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
|
||||
blob_size >>= 8u;
|
||||
}
|
||||
std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
|
||||
idx += buf.size();
|
||||
std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
|
||||
idx += blob->size();
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
|
||||
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
|
||||
|
@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la
|
|||
last_blob_page = page.first;
|
||||
}
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
|
||||
const std::vector<std::string> &file_names) {
|
||||
if (nullptr == writer_ptr) {
|
||||
MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
|
||||
return FAILED;
|
||||
}
|
||||
auto res = (*writer_ptr)->Open(file_names, false);
|
||||
if (SUCCESS != res) {
|
||||
MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
|
||||
return FAILED;
|
||||
}
|
||||
(*writer_ptr)->SetHeaderSize(1 << 24);
|
||||
(*writer_ptr)->SetPageSize(1 << 25);
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
|
|||
page_in_handle.close();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
|
||||
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
|
||||
uint64_t &schema_id) {
|
||||
if (nullptr == header_ptr) {
|
||||
MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
|
||||
return FAILED;
|
||||
}
|
||||
auto schema_ptr = Schema::Build("mindrecord", schema);
|
||||
if (nullptr == schema_ptr) {
|
||||
MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
|
||||
return FAILED;
|
||||
}
|
||||
schema_id = (*header_ptr)->AddSchema(schema_ptr);
|
||||
// create index
|
||||
std::vector<std::pair<uint64_t, std::string>> id_index_fields;
|
||||
if (!index_fields.empty()) {
|
||||
for (auto &el : index_fields) {
|
||||
id_index_fields.emplace_back(schema_id, el);
|
||||
}
|
||||
if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
|
||||
MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
|
||||
blob_fields = build_schema_ptr->GetBlobFields();
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,13 +38,13 @@ from mindspore._c_expression import typing
|
|||
|
||||
from mindspore import log as logger
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
|
||||
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
check_rename, check_numpyslicesdataset, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -1044,6 +1044,34 @@ class Dataset:
|
|||
|
||||
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
|
||||
|
||||
@check_save
|
||||
def save(self, file_name, num_files=1, file_type='mindrecord'):
|
||||
"""
|
||||
Save the dynamic data processed by dataset pipeline as common dataset format, support: mindrecord.
|
||||
|
||||
Note:
|
||||
1. To save the samples in order, should set dataset's shuffle false and num_files 1.
|
||||
2. Before call the function, do not use batch, repeat operator or data augmentation operators
|
||||
with random attribute in map operator.
|
||||
3. Mindreocrd do not support np.uint64, multi-dimensional np.uint8(drop dimension) and
|
||||
multi-dimensional string.
|
||||
|
||||
Args:
|
||||
file_name (str): Path to dataset file.
|
||||
num_files (int, optional): Number of dataset files.(default=1).
|
||||
file_type (str, optional): dataset format.(default='mindrecord')
|
||||
|
||||
"""
|
||||
|
||||
if num_files == 1:
|
||||
file_names = [file_name]
|
||||
else:
|
||||
suffix = len(str(num_files - 1))
|
||||
file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
|
||||
for x in range(num_files)]
|
||||
|
||||
return SaveOp(self).save(file_names, file_type)
|
||||
|
||||
def create_tuple_iterator(self, columns=None):
|
||||
"""
|
||||
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
|
||||
|
|
|
@ -173,6 +173,7 @@ class Iterator:
|
|||
|
||||
# Convert python node into C node and add to C layer execution tree in postorder traversal.
|
||||
def __convert_node_postorder(self, node):
|
||||
self.check_node_type(node)
|
||||
op_type = self.__get_dataset_type(node)
|
||||
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
|
||||
|
||||
|
@ -224,6 +225,10 @@ class Iterator:
|
|||
self._index += 1
|
||||
return data
|
||||
|
||||
@abstractmethod
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def get_output_shapes(self):
|
||||
return [t for t in self.depipeline.GetOutputShapes()]
|
||||
|
||||
|
@ -245,11 +250,27 @@ class Iterator:
|
|||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
class SaveOp(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def get_next(self):
|
||||
pass
|
||||
|
||||
def check_node_type(self, node):
|
||||
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
|
||||
logger.warning("Used shuffle, repeat, batch before save operator.")
|
||||
|
||||
def save(self, file_names, file_type):
|
||||
return self.depipeline.SaveDataset(file_names, file_type)
|
||||
|
||||
|
||||
class DictIterator(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
@ -269,6 +290,8 @@ class TupleIterator(Iterator):
|
|||
"""
|
||||
The derived class of Iterator with list type.
|
||||
"""
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __init__(self, dataset, columns=None):
|
||||
if columns is not None:
|
||||
|
|
|
@ -246,7 +246,24 @@ def check_celebadataset(method):
|
|||
|
||||
return new_method
|
||||
|
||||
def check_save(method):
|
||||
"""A wrapper that wrap a parameter checker to the save op."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_files']
|
||||
nreq_param_str = ['file_name', 'file_type']
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
|
||||
raise ValueError("num_files should between {} and {}.".format(1, 1000))
|
||||
validate_dataset_param_value(nreq_param_str, param_dict, str)
|
||||
if param_dict.get('file_type') != 'mindrecord':
|
||||
raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
def check_minddataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
|
||||
|
||||
|
|
|
@ -0,0 +1,390 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
This is the test module for saveOp.
|
||||
"""
|
||||
import os
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.mindrecord import FileWriter
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
|
||||
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
|
||||
|
||||
FILES_NUM = 1
|
||||
num_readers = 1
|
||||
|
||||
|
||||
@pytest.fixture(name="add_and_remove_cv_file")
|
||||
def fixture_remove():
|
||||
"""add/remove cv file"""
|
||||
if os.path.exists("{}".format(CV_FILE_NAME1)):
|
||||
os.remove("{}".format(CV_FILE_NAME1))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME1)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME1))
|
||||
|
||||
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
||||
os.remove("{}".format(CV_FILE_NAME2))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME2))
|
||||
yield "yield_cv_data"
|
||||
if os.path.exists("{}".format(CV_FILE_NAME1)):
|
||||
os.remove("{}".format(CV_FILE_NAME1))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME1)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME1))
|
||||
|
||||
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
||||
os.remove("{}".format(CV_FILE_NAME2))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME2))
|
||||
|
||||
|
||||
def test_case_00(add_and_remove_cv_file): # only bin data
|
||||
data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
|
||||
{"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
|
||||
{"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
|
||||
{"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
|
||||
{"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8')}]
|
||||
schema = {
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"image3": {"type": "bytes"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"}}
|
||||
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
|
||||
writer.add_schema(schema, "schema")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
|
||||
d1.save(CV_FILE_NAME2, FILES_NUM)
|
||||
data_value_to_list = []
|
||||
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
|
||||
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
|
||||
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert d2.get_dataset_size() == 5
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 5
|
||||
|
||||
|
||||
def test_case_01(add_and_remove_cv_file): # only raw data
|
||||
data = [{"file_name": "001.jpg", "label": 43},
|
||||
{"file_name": "002.jpg", "label": 91},
|
||||
{"file_name": "003.jpg", "label": 61},
|
||||
{"file_name": "004.jpg", "label": 29},
|
||||
{"file_name": "005.jpg", "label": 78},
|
||||
{"file_name": "006.jpg", "label": 37}]
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"}
|
||||
}
|
||||
|
||||
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
|
||||
writer.add_schema(schema, "schema")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
|
||||
d1.save(CV_FILE_NAME2, FILES_NUM)
|
||||
|
||||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert d2.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
logger.info(item)
|
||||
assert len(item) == 2
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_case_02(add_and_remove_cv_file): # muti-bytes
|
||||
data = [{"file_name": "001.jpg", "label": 43,
|
||||
"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12345,
|
||||
"float64": 1987654321.123456785,
|
||||
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "002.jpg", "label": 91,
|
||||
"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12445,
|
||||
"float64": 1987654321.123456786,
|
||||
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "003.jpg", "label": 61,
|
||||
"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12545,
|
||||
"float64": 1987654321.123456787,
|
||||
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "004.jpg", "label": 29,
|
||||
"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12645,
|
||||
"float64": 1987654321.123456788,
|
||||
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image4 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "005.jpg", "label": 78,
|
||||
"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12745,
|
||||
"float64": 1987654321.123456789,
|
||||
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "006.jpg", "label": 37,
|
||||
"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
|
||||
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
|
||||
123414314.2141243, 87.1212122], dtype=np.float64),
|
||||
"float32": 3456.12745,
|
||||
"float64": 1987654321.123456789,
|
||||
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32),
|
||||
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
|
||||
]
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"float32_array": {"type": "float32", "shape": [-1]},
|
||||
"float64_array": {"type": "float64", "shape": [-1]},
|
||||
"float32": {"type": "float32"},
|
||||
"float64": {"type": "float64"},
|
||||
"source_sos_ids": {"type": "int32", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"image3": {"type": "bytes"},
|
||||
"label": {"type": "int32"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"}}
|
||||
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
|
||||
writer.add_schema(schema, "schema")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
|
||||
d1.save(CV_FILE_NAME2, FILES_NUM)
|
||||
data_value_to_list = []
|
||||
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
|
||||
new_data['float32_array'] = item["float32_array"]
|
||||
new_data['float64_array'] = item["float64_array"]
|
||||
new_data['float32'] = item["float32"]
|
||||
new_data['float64'] = item["float64"]
|
||||
new_data['source_sos_ids'] = item["source_sos_ids"]
|
||||
new_data['source_sos_mask'] = item["source_sos_mask"]
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
|
||||
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
|
||||
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert d2.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
assert len(item) == 13
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
if item[field].dtype == np.float32:
|
||||
assert (item[field] ==
|
||||
np.array(data_value_to_list[num_iter][field], np.float32)).all()
|
||||
else:
|
||||
assert (item[field] ==
|
||||
data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def generator_1d():
|
||||
for i in range(10):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
def test_case_03(add_and_remove_cv_file):
|
||||
|
||||
# apply dataset operations
|
||||
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
|
||||
|
||||
d1.save(CV_FILE_NAME2)
|
||||
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
|
||||
i = 0
|
||||
for item in d2.create_dict_iterator(): # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
|
||||
|
||||
def generator_with_type(t):
|
||||
for i in range(64):
|
||||
yield (np.array([i], dtype=t),)
|
||||
|
||||
|
||||
def type_tester(t):
|
||||
logger.info("Test with Type {}".format(t.__name__))
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False)
|
||||
|
||||
data1 = data1.batch(4)
|
||||
|
||||
data1 = data1.repeat(3)
|
||||
|
||||
data1.save(CV_FILE_NAME2)
|
||||
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
|
||||
i = 0
|
||||
num_repeat = 0
|
||||
for item in d2.create_dict_iterator(): # each data is a dictionary
|
||||
golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
|
||||
logger.info(item)
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 4
|
||||
if i == 64:
|
||||
i = 0
|
||||
num_repeat += 1
|
||||
assert num_repeat == 3
|
||||
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
||||
os.remove("{}".format(CV_FILE_NAME2))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME2))
|
||||
|
||||
|
||||
def test_case_04():
|
||||
# uint8 will drop shape as mindrecord store uint8 as bytes
|
||||
types = [np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint16, np.uint32, np.float32, np.float64]
|
||||
|
||||
for t in types:
|
||||
type_tester(t)
|
||||
|
||||
|
||||
def test_case_05(add_and_remove_cv_file):
|
||||
|
||||
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
|
||||
|
||||
with pytest.raises(Exception, match="num_files should between 1 and 1000."):
|
||||
d1.save(CV_FILE_NAME2, 0)
|
||||
|
||||
|
||||
def test_case_06(add_and_remove_cv_file):
|
||||
|
||||
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
|
||||
|
||||
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
|
||||
d1.save(CV_FILE_NAME2, 1, "tfrecord")
|
Loading…
Reference in New Issue