forked from mindspore-Ecosystem/mindspore
!999 [MD] mindrecord support reading file list
Merge pull request !999 from liyong126/mindrecord_file_list
This commit is contained in:
commit
1a98c6b459
|
@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
|
std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
|
||||||
(void)builder->SetDatasetFile(ToString(args["dataset_file"]));
|
bool load_dataset = ToBool(args["load_dataset"]);
|
||||||
|
if (load_dataset == true) {
|
||||||
|
(void)builder->SetDatasetFile({ToString(args["dataset_file"])});
|
||||||
|
} else {
|
||||||
|
(void)builder->SetDatasetFile(ToStringVector(args["dataset_file"]));
|
||||||
|
}
|
||||||
|
(void)builder->SetLoadDataset(load_dataset);
|
||||||
std::vector<std::string> in_col_names;
|
std::vector<std::string> in_col_names;
|
||||||
if (!args["columns_list"].is_none()) {
|
if (!args["columns_list"].is_none()) {
|
||||||
in_col_names = ToStringVector(args["columns_list"]);
|
in_col_names = ToStringVector(args["columns_list"]);
|
||||||
|
|
|
@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
|
||||||
});
|
});
|
||||||
|
|
||||||
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
||||||
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
|
.def_static("get_num_rows",
|
||||||
int64_t count = 0;
|
[](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
|
||||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
int64_t count = 0;
|
||||||
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||||
auto create = sampler.attr("_create_for_minddataset");
|
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
auto create = sampler.attr("_create_for_minddataset");
|
||||||
}
|
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
|
}
|
||||||
return count;
|
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
|
||||||
});
|
return count;
|
||||||
|
});
|
||||||
|
|
||||||
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
||||||
.def_static("get_num_rows_and_classes",
|
.def_static("get_num_rows_and_classes",
|
||||||
|
|
|
@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
|
||||||
using mindrecord::ShardReader;
|
using mindrecord::ShardReader;
|
||||||
|
|
||||||
// Builder constructor. Creates the builder object.
|
// Builder constructor. Creates the builder object.
|
||||||
MindRecordOp::Builder::Builder() : build_dataset_file_("") {
|
MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
|
||||||
// Some arguments to the MindRecordOp constructor have a default argument that is taken
|
// Some arguments to the MindRecordOp constructor have a default argument that is taken
|
||||||
// from the client config.
|
// from the client config.
|
||||||
// The user may choose to change these values for the construction of the StorageOp by
|
// The user may choose to change these values for the construction of the StorageOp by
|
||||||
|
@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
|
||||||
"Building a MindRecordOp that has not provided a file.");
|
"Building a MindRecordOp that has not provided a file.");
|
||||||
}
|
}
|
||||||
|
|
||||||
new_mind_record_op = std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_rows_per_buffer_,
|
new_mind_record_op = std::make_shared<MindRecordOp>(
|
||||||
build_dataset_file_, build_op_connector_queue_size_,
|
build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_,
|
||||||
build_columns_to_load_, build_operators_, build_block_reader_);
|
build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_);
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(new_mind_record_op->Init());
|
RETURN_IF_NOT_OK(new_mind_record_op->Init());
|
||||||
|
|
||||||
|
@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
|
||||||
Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
|
Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
|
||||||
|
|
||||||
// Constructor of the MindRecordOp.
|
// Constructor of the MindRecordOp.
|
||||||
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
|
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer,
|
||||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
std::vector<std::string> dataset_file, bool load_dataset, int32_t op_connector_queue_size,
|
||||||
|
const std::vector<std::string> &columns_to_load,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
|
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
|
||||||
: ParallelOp(num_mind_record_workers, op_connector_queue_size),
|
: ParallelOp(num_mind_record_workers, op_connector_queue_size),
|
||||||
rows_per_buffer_(rows_per_buffer),
|
rows_per_buffer_(rows_per_buffer),
|
||||||
dataset_file_(dataset_file),
|
dataset_file_(dataset_file),
|
||||||
|
load_dataset_(load_dataset),
|
||||||
columns_to_load_(columns_to_load),
|
columns_to_load_(columns_to_load),
|
||||||
operators_(operators),
|
operators_(operators),
|
||||||
num_mind_record_workers_(num_mind_record_workers),
|
num_mind_record_workers_(num_mind_record_workers),
|
||||||
|
@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
|
||||||
// Private helper method to encapsulate some common construction/reset tasks
|
// Private helper method to encapsulate some common construction/reset tasks
|
||||||
Status MindRecordOp::Init() {
|
Status MindRecordOp::Init() {
|
||||||
shard_reader_ = std::make_unique<ShardReader>();
|
shard_reader_ = std::make_unique<ShardReader>();
|
||||||
auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_);
|
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
|
||||||
|
block_reader_);
|
||||||
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED,
|
CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS,
|
||||||
"MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
|
"MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
|
||||||
|
|
||||||
data_schema_ = std::make_unique<DataSchema>();
|
data_schema_ = std::make_unique<DataSchema>();
|
||||||
|
@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
|
||||||
// Call the super class for displaying any common detailed info
|
// Call the super class for displaying any common detailed info
|
||||||
ParallelOp::Print(out, show_all);
|
ParallelOp::Print(out, show_all);
|
||||||
// Then show any custom derived-internal stuff
|
// Then show any custom derived-internal stuff
|
||||||
out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_
|
out << "\n Dataset file : ";
|
||||||
<< "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_
|
for (auto &file : dataset_file_) {
|
||||||
|
out << file << " ";
|
||||||
|
}
|
||||||
|
out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_
|
||||||
|
<< "\nNumber of buffers : " << buffers_needed_
|
||||||
<< "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
|
<< "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
|
||||||
int64_t *count) {
|
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
|
||||||
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
||||||
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
|
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count);
|
||||||
if (rc == MSRStatus::FAILED) {
|
if (rc == MSRStatus::FAILED) {
|
||||||
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
|
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Builder &SetDatasetFile(const std::string &file) {
|
Builder &SetDatasetFile(const std::vector<std::string> &files) {
|
||||||
build_dataset_file_ = file;
|
build_dataset_file_ = files;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Builder &SetLoadDataset(bool load_dataset) {
|
||||||
|
build_load_dataset_ = load_dataset;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
Status SanityCheck() const;
|
Status SanityCheck() const;
|
||||||
|
|
||||||
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
|
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
|
||||||
|
@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
|
||||||
int32_t builder_num_workers_;
|
int32_t builder_num_workers_;
|
||||||
int32_t build_rows_per_buffer_;
|
int32_t build_rows_per_buffer_;
|
||||||
int32_t build_op_connector_queue_size_;
|
int32_t build_op_connector_queue_size_;
|
||||||
std::string build_dataset_file_;
|
std::vector<std::string> build_dataset_file_;
|
||||||
|
bool build_load_dataset_;
|
||||||
std::vector<std::string> build_columns_to_load_;
|
std::vector<std::string> build_columns_to_load_;
|
||||||
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
|
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
|
||||||
bool build_block_reader_;
|
bool build_block_reader_;
|
||||||
|
@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
|
||||||
// @note The builder class should be used to call it
|
// @note The builder class should be used to call it
|
||||||
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
|
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
|
||||||
// @param rows_per_buffer - The requested number of rows per buffer
|
// @param rows_per_buffer - The requested number of rows per buffer
|
||||||
// @param dataset_file - A shard file
|
// @param dataset_file - dataset files
|
||||||
// @param op_connector_queue_size - The output connector queue size
|
// @param op_connector_queue_size - The output connector queue size
|
||||||
// @param columns_to_load - The list of columns to use (column name)
|
// @param columns_to_load - The list of columns to use (column name)
|
||||||
// @param operators - ShardOperators for Shuffle, Category, Sample
|
// @param operators - ShardOperators for Shuffle, Category, Sample
|
||||||
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
|
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector<std::string> dataset_file,
|
||||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
bool load_dataset, int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
|
||||||
|
|
||||||
// Destructor
|
// Destructor
|
||||||
|
@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
|
||||||
// Getter method
|
// Getter method
|
||||||
int32_t num_rows() const { return num_rows_; }
|
int32_t num_rows() const { return num_rows_; }
|
||||||
|
|
||||||
// Getter method
|
static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
|
||||||
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||||
int64_t *count);
|
|
||||||
|
|
||||||
// Getter method
|
// Getter method
|
||||||
int32_t rows_per_buffer() const { return rows_per_buffer_; }
|
int32_t rows_per_buffer() const { return rows_per_buffer_; }
|
||||||
|
|
||||||
// Getter method
|
// Getter method
|
||||||
std::string dataset_file() const { return dataset_file_; }
|
std::vector<std::string> dataset_file() const { return dataset_file_; }
|
||||||
|
|
||||||
// Getter method
|
// Getter method
|
||||||
std::vector<std::string> columns_to_load() const { return columns_to_load_; }
|
std::vector<std::string> columns_to_load() const { return columns_to_load_; }
|
||||||
|
|
||||||
bool block_reader() const { return block_reader_; }
|
bool block_reader() const { return block_reader_; }
|
||||||
|
|
||||||
|
bool load_dataset() const { return load_dataset_; }
|
||||||
|
|
||||||
Status Init();
|
Status Init();
|
||||||
|
|
||||||
Status SetColumnsBlob();
|
Status SetColumnsBlob();
|
||||||
|
@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
|
||||||
Status FetchBlockBuffer(const int32_t &buffer_id);
|
Status FetchBlockBuffer(const int32_t &buffer_id);
|
||||||
|
|
||||||
int32_t rows_per_buffer_; // The number of requested rows per buffer.
|
int32_t rows_per_buffer_; // The number of requested rows per buffer.
|
||||||
std::string dataset_file_; // A dataset file
|
std::vector<std::string> dataset_file_; // dataset files
|
||||||
|
bool load_dataset_; // load dataset from single file or not
|
||||||
std::vector<std::string> columns_to_load_; // Columns to load from dataset
|
std::vector<std::string> columns_to_load_; // Columns to load from dataset
|
||||||
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
|
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
|
||||||
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
|
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
|
||||||
|
|
|
@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
|
||||||
case IO_FAILED:
|
case IO_FAILED:
|
||||||
return "io operate failed";
|
return "io operate failed";
|
||||||
break;
|
break;
|
||||||
|
case MATCH_HEADER_FAILED:
|
||||||
|
return "match header failed";
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return "invalid error no";
|
return "invalid error no";
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
|
||||||
void BindShardReader(const py::module *m) {
|
void BindShardReader(const py::module *m) {
|
||||||
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
|
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector<std::string> &,
|
.def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
|
||||||
|
const std::vector<std::string> &,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
||||||
ShardReader::OpenPy)
|
ShardReader::OpenPy)
|
||||||
.def("launch", &ShardReader::Launch)
|
.def("launch", &ShardReader::Launch)
|
||||||
|
@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
|
||||||
void BindShardSegment(py::module *m) {
|
void BindShardSegment(py::module *m) {
|
||||||
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
|
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector<std::string> &,
|
.def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
|
||||||
|
const std::vector<std::string> &,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
||||||
ShardSegment::OpenPy)
|
ShardSegment::OpenPy)
|
||||||
.def("get_category_fields",
|
.def("get_category_fields",
|
||||||
|
|
|
@ -72,7 +72,8 @@ enum MSRStatus {
|
||||||
ILLEGAL_PARAMETERS,
|
ILLEGAL_PARAMETERS,
|
||||||
GET_PAGE_BY_GROUP_ID_FAILED,
|
GET_PAGE_BY_GROUP_ID_FAILED,
|
||||||
GET_SYSTEM_STATE_FAILED,
|
GET_SYSTEM_STATE_FAILED,
|
||||||
IO_FAILED
|
IO_FAILED,
|
||||||
|
MATCH_HEADER_FAILED
|
||||||
};
|
};
|
||||||
|
|
||||||
// convert error no to string message
|
// convert error no to string message
|
||||||
|
|
|
@ -35,10 +35,11 @@ class ShardHeader {
|
||||||
public:
|
public:
|
||||||
ShardHeader();
|
ShardHeader();
|
||||||
|
|
||||||
MSRStatus Build(const std::string &file_path);
|
|
||||||
|
|
||||||
~ShardHeader() = default;
|
~ShardHeader() = default;
|
||||||
|
|
||||||
|
MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
|
||||||
|
|
||||||
|
static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
|
||||||
/// \brief add the schema and save it
|
/// \brief add the schema and save it
|
||||||
/// \param[in] schema the schema needs to be added
|
/// \param[in] schema the schema needs to be added
|
||||||
/// \return the last schema's id
|
/// \return the last schema's id
|
||||||
|
@ -126,7 +127,7 @@ class ShardHeader {
|
||||||
MSRStatus FileToPages(const std::string dump_file_name);
|
MSRStatus FileToPages(const std::string dump_file_name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MSRStatus InitializeHeader(const std::vector<json> &headers);
|
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
|
||||||
|
|
||||||
/// \brief get the headers from all the shard data
|
/// \brief get the headers from all the shard data
|
||||||
/// \param[in] the shard data real path
|
/// \param[in] the shard data real path
|
||||||
|
@ -137,9 +138,9 @@ class ShardHeader {
|
||||||
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
|
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
|
||||||
|
|
||||||
/// \brief check the binary file status
|
/// \brief check the binary file status
|
||||||
MSRStatus CheckFileStatus(const std::string &path);
|
static MSRStatus CheckFileStatus(const std::string &path);
|
||||||
|
|
||||||
std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
|
static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
|
||||||
|
|
||||||
void ParseHeader(const json &header);
|
void ParseHeader(const json &header);
|
||||||
|
|
||||||
|
@ -149,7 +150,7 @@ class ShardHeader {
|
||||||
|
|
||||||
MSRStatus CheckIndexField(const std::string &field, const json &schema);
|
MSRStatus CheckIndexField(const std::string &field, const json &schema);
|
||||||
|
|
||||||
void ParsePage(const json &page);
|
void ParsePage(const json &page, int shard_index, bool load_dataset);
|
||||||
|
|
||||||
MSRStatus ParseStatistics(const json &statistics);
|
MSRStatus ParseStatistics(const json &statistics);
|
||||||
|
|
||||||
|
|
|
@ -68,23 +68,25 @@ class ShardReader {
|
||||||
virtual ~ShardReader();
|
virtual ~ShardReader();
|
||||||
|
|
||||||
/// \brief open files and initialize reader, c++ API
|
/// \brief open files and initialize reader, c++ API
|
||||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||||
|
/// \param[in] load_dataset load dataset from single file or not
|
||||||
/// \param[in] n_consumer number of threads when reading
|
/// \param[in] n_consumer number of threads when reading
|
||||||
/// \param[in] selected_columns column list to be populated
|
/// \param[in] selected_columns column list to be populated
|
||||||
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
||||||
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
|
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
|
||||||
/// \return MSRStatus the status of MSRStatus
|
/// \return MSRStatus the status of MSRStatus
|
||||||
MSRStatus Open(const std::string &file_path, int n_consumer = 4,
|
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
|
||||||
const std::vector<std::string> &selected_columns = {},
|
const std::vector<std::string> &selected_columns = {},
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
|
||||||
|
|
||||||
/// \brief open files and initialize reader, python API
|
/// \brief open files and initialize reader, python API
|
||||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||||
|
/// \param[in] load_dataset load dataset from single file or not
|
||||||
/// \param[in] n_consumer number of threads when reading
|
/// \param[in] n_consumer number of threads when reading
|
||||||
/// \param[in] selected_columns column list to be populated
|
/// \param[in] selected_columns column list to be populated
|
||||||
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
||||||
/// \return MSRStatus the status of MSRStatus
|
/// \return MSRStatus the status of MSRStatus
|
||||||
MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4,
|
MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
|
||||||
const std::vector<std::string> &selected_columns = {},
|
const std::vector<std::string> &selected_columns = {},
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
|
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
|
||||||
|
|
||||||
|
@ -114,11 +116,13 @@ class ShardReader {
|
||||||
int GetShardCount() const;
|
int GetShardCount() const;
|
||||||
|
|
||||||
/// \brief get the number of rows in database
|
/// \brief get the number of rows in database
|
||||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||||
|
/// \param[in] load_dataset load dataset from single file or not
|
||||||
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
|
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
|
||||||
/// \param[out] count # of rows
|
/// \param[out] count # of rows
|
||||||
/// \return MSRStatus the status of MSRStatus
|
/// \return MSRStatus the status of MSRStatus
|
||||||
MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||||
|
const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||||
|
|
||||||
/// \brief shuffle task with incremental seed
|
/// \brief shuffle task with incremental seed
|
||||||
/// \return void
|
/// \return void
|
||||||
|
@ -220,7 +224,7 @@ class ShardReader {
|
||||||
std::vector<std::vector<json>> &column_values);
|
std::vector<std::vector<json>> &column_values);
|
||||||
|
|
||||||
/// \brief initialize reader
|
/// \brief initialize reader
|
||||||
MSRStatus Init(const std::string &file_path);
|
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
|
||||||
|
|
||||||
/// \brief validate column list
|
/// \brief validate column list
|
||||||
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
|
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
|
||||||
|
@ -292,8 +296,9 @@ class ShardReader {
|
||||||
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
||||||
|
|
||||||
/// \brief get number of classes
|
/// \brief get number of classes
|
||||||
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
|
int64_t GetNumClasses(const std::string &category_field);
|
||||||
|
|
||||||
|
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
|
||||||
/// \brief get exactly blob fields data by indices
|
/// \brief get exactly blob fields data by indices
|
||||||
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
|
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
|
||||||
std::vector<uint32_t> &ordered_selected_columns_index);
|
std::vector<uint32_t> &ordered_selected_columns_index);
|
||||||
|
|
|
@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
|
||||||
write_success_(true) {}
|
write_success_(true) {}
|
||||||
|
|
||||||
MSRStatus ShardIndexGenerator::Build() {
|
MSRStatus ShardIndexGenerator::Build() {
|
||||||
|
auto ret = ShardHeader::BuildSingleHeader(file_path_);
|
||||||
|
if (ret.first != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto json_header = ret.second;
|
||||||
|
|
||||||
|
auto ret2 = GetParentDir(file_path_);
|
||||||
|
if (SUCCESS != ret2.first) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<std::string> real_addresses;
|
||||||
|
for (const auto &path : json_header["shard_addresses"]) {
|
||||||
|
std::string abs_path = ret2.second + string(path);
|
||||||
|
real_addresses.emplace_back(abs_path);
|
||||||
|
}
|
||||||
ShardHeader header = ShardHeader();
|
ShardHeader header = ShardHeader();
|
||||||
if (header.Build(file_path_) != SUCCESS) {
|
if (header.BuildDataset(real_addresses) == FAILED) {
|
||||||
MS_LOG(ERROR) << "Build shard schema failed.";
|
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
shard_header_ = header;
|
shard_header_ = header;
|
||||||
|
|
|
@ -47,20 +47,55 @@ ShardReader::ShardReader() {
|
||||||
block_reader_ = false;
|
block_reader_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::Init(const std::string &file_path) {
|
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
||||||
if (!IsLegalFile(file_path)) {
|
if (!IsLegalFile(file_path)) {
|
||||||
return FAILED;
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
ShardHeader sh = ShardHeader();
|
auto ret = ShardHeader::BuildSingleHeader(file_path);
|
||||||
if (sh.Build(file_path) == FAILED) {
|
if (ret.first != SUCCESS) {
|
||||||
return FAILED;
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
auto header = ret.second;
|
||||||
header_size_ = shard_header_->GetHeaderSize();
|
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||||
page_size_ = shard_header_->GetPageSize();
|
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
||||||
file_paths_ = shard_header_->GetShardAddresses();
|
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
||||||
|
return {SUCCESS, header["shard_addresses"]};
|
||||||
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||||
|
std::string file_path = file_paths[0];
|
||||||
|
json first_meta_data = json();
|
||||||
|
auto ret = GetMeta(file_path, first_meta_data);
|
||||||
|
if (ret.first != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (file_paths.size() == 1 && load_dataset == true) {
|
||||||
|
auto ret2 = GetParentDir(file_path);
|
||||||
|
if (SUCCESS != ret2.first) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<std::string> real_addresses;
|
||||||
|
for (const auto &path : ret.second) {
|
||||||
|
std::string abs_path = ret2.second + string(path);
|
||||||
|
real_addresses.emplace_back(abs_path);
|
||||||
|
}
|
||||||
|
file_paths_ = real_addresses;
|
||||||
|
} else if (file_paths.size() >= 1 && load_dataset == false) {
|
||||||
|
file_paths_ = file_paths;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Error in parameter file_path or load_dataset.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
for (const auto &file : file_paths_) {
|
for (const auto &file : file_paths_) {
|
||||||
|
json meta_data = json();
|
||||||
|
auto ret1 = GetMeta(file, meta_data);
|
||||||
|
if (ret1.first != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (meta_data != first_meta_data) {
|
||||||
|
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
sqlite3 *db = nullptr;
|
sqlite3 *db = nullptr;
|
||||||
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
|
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
|
||||||
int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||||
|
@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
|
||||||
}
|
}
|
||||||
database_paths_.push_back(db);
|
database_paths_.push_back(db);
|
||||||
}
|
}
|
||||||
|
ShardHeader sh = ShardHeader();
|
||||||
|
if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||||
|
header_size_ = shard_header_->GetHeaderSize();
|
||||||
|
page_size_ = shard_header_->GetPageSize();
|
||||||
num_rows_ = 0;
|
num_rows_ = 0;
|
||||||
auto row_group_summary = ReadRowGroupSummary();
|
auto row_group_summary = ReadRowGroupSummary();
|
||||||
for (const auto &rg : row_group_summary) {
|
for (const auto &rg : row_group_summary) {
|
||||||
|
@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
||||||
fs->close();
|
fs->close();
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
json label_json = json::from_msgpack(label_raw);
|
json label_json = json::from_msgpack(label_raw);
|
||||||
json tmp;
|
json tmp;
|
||||||
if (!columns.empty()) {
|
if (!columns.empty()) {
|
||||||
|
@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) {
|
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
||||||
ShardHeader sh = ShardHeader();
|
auto shard_count = file_paths_.size();
|
||||||
if (sh.Build(file_path) == FAILED) {
|
auto index_fields = shard_header_->GetFields();
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
auto header = std::make_shared<ShardHeader>(sh);
|
|
||||||
auto file_paths = header->GetShardAddresses();
|
|
||||||
auto shard_count = file_paths.size();
|
|
||||||
auto index_fields = header->GetFields();
|
|
||||||
|
|
||||||
std::map<std::string, int64_t> map_schema_id_fields;
|
std::map<std::string, int64_t> map_schema_id_fields;
|
||||||
for (auto &field : index_fields) {
|
for (auto &field : index_fields) {
|
||||||
|
@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
||||||
std::set<std::string> categories;
|
std::set<std::string> categories;
|
||||||
for (int x = 0; x < shard_count; x++) {
|
for (int x = 0; x < shard_count; x++) {
|
||||||
sqlite3 *db = nullptr;
|
sqlite3 *db = nullptr;
|
||||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||||
if (SQLITE_OK != rc) {
|
if (SQLITE_OK != rc) {
|
||||||
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
|
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
||||||
return categories.size();
|
return categories.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op,
|
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||||
int64_t *count) {
|
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
|
||||||
if (Init(file_path) == FAILED) {
|
if (SUCCESS != Init(file_paths, load_dataset)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
int64_t num_samples = num_rows_;
|
int64_t num_samples = num_rows_;
|
||||||
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||||
std::string category_field = category_op->GetCategoryField();
|
std::string category_field = category_op->GetCategoryField();
|
||||||
auto num_classes = GetNumClasses(file_path, category_field);
|
auto num_classes = GetNumClasses(category_field);
|
||||||
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
|
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
|
||||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||||
num_samples = op->GetNumSamples(num_rows_, 0);
|
num_samples = op->GetNumSamples(num_rows_, 0);
|
||||||
|
@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
|
||||||
const std::vector<std::string> &selected_columns,
|
const std::vector<std::string> &selected_columns,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
|
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
|
||||||
// Open file and set header by ShardReader
|
// Open file and set header by ShardReader
|
||||||
if (Init(file_path) == FAILED) {
|
auto ret = Init(file_paths, load_dataset);
|
||||||
return FAILED;
|
if (SUCCESS != ret) {
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
auto thread_limit = GetMaxThreadNum();
|
auto thread_limit = GetMaxThreadNum();
|
||||||
if (n_consumer > thread_limit) {
|
if (n_consumer > thread_limit) {
|
||||||
|
@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer,
|
MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
|
||||||
const std::vector<std::string> &selected_columns,
|
const std::vector<std::string> &selected_columns,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||||
// Open file and set header by ShardReader
|
// Open file and set header by ShardReader
|
||||||
if (Init(file_path) == FAILED) {
|
if (SUCCESS != Init(file_paths, load_dataset)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
// should remove blob field from selected_columns when call from python
|
// should remove blob field from selected_columns when call from python
|
||||||
|
|
|
@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
||||||
if (!IsLegalFile(path)) {
|
if (!IsLegalFile(path)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
ShardHeader sh = ShardHeader();
|
auto ret1 = ShardHeader::BuildSingleHeader(path);
|
||||||
if (sh.Build(path) == FAILED) {
|
if (ret1.first != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
auto json_header = ret1.second;
|
||||||
auto paths = shard_header_->GetShardAddresses();
|
auto ret2 = GetParentDir(path);
|
||||||
|
if (SUCCESS != ret2.first) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<std::string> real_addresses;
|
||||||
|
for (const auto &path : json_header["shard_addresses"]) {
|
||||||
|
std::string abs_path = ret2.second + string(path);
|
||||||
|
real_addresses.emplace_back(abs_path);
|
||||||
|
}
|
||||||
|
ShardHeader header = ShardHeader();
|
||||||
|
if (header.BuildDataset(real_addresses) == FAILED) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
shard_header_ = std::make_shared<ShardHeader>(header);
|
||||||
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
|
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
|
||||||
if (ret == FAILED) {
|
if (ret == FAILED) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
||||||
if (ret == FAILED) {
|
if (ret == FAILED) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
ret = Open(paths, true);
|
ret = Open(json_header["shard_addresses"], true);
|
||||||
if (ret == FAILED) {
|
if (ret == FAILED) {
|
||||||
MS_LOG(ERROR) << "Open file failed";
|
MS_LOG(ERROR) << "Open file failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
|
|
@ -35,8 +35,9 @@ namespace mindrecord {
|
||||||
std::atomic<bool> thread_status(false);
|
std::atomic<bool> thread_status(false);
|
||||||
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
|
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
|
||||||
|
|
||||||
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
|
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
|
||||||
shard_count_ = headers.size();
|
shard_count_ = headers.size();
|
||||||
|
int shard_index = 0;
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (const auto &header : headers) {
|
for (const auto &header : headers) {
|
||||||
if (first) {
|
if (first) {
|
||||||
|
@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
|
||||||
header_size_ = header["header_size"].get<uint64_t>();
|
header_size_ = header["header_size"].get<uint64_t>();
|
||||||
page_size_ = header["page_size"].get<uint64_t>();
|
page_size_ = header["page_size"].get<uint64_t>();
|
||||||
}
|
}
|
||||||
ParsePage(header["page"]);
|
ParsePage(header["page"], shard_index, load_dataset);
|
||||||
|
shard_index++;
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
|
||||||
return {SUCCESS, json_header};
|
return {SUCCESS, json_header};
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardHeader::Build(const std::string &file_path) {
|
std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
|
||||||
auto ret = ValidateHeader(file_path);
|
auto ret = ValidateHeader(file_path);
|
||||||
if (SUCCESS != ret.first) {
|
if (SUCCESS != ret.first) {
|
||||||
return FAILED;
|
return {FAILED, json()};
|
||||||
}
|
}
|
||||||
json main_header = ret.second;
|
json raw_header = ret.second;
|
||||||
json addresses = main_header["shard_addresses"];
|
json header = {{"shard_addresses", raw_header["shard_addresses"]},
|
||||||
vector<string> real_addresses;
|
{"header_size", raw_header["header_size"]},
|
||||||
auto ret1 = GetParentDir(file_path);
|
{"page_size", raw_header["page_size"]},
|
||||||
if (SUCCESS != ret1.first) {
|
{"index_fields", raw_header["index_fields"]},
|
||||||
return FAILED;
|
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
|
||||||
}
|
{"schema", raw_header["schema"][0]["schema"]},
|
||||||
std::string parent_dir = ret1.second;
|
{"version", raw_header["version"]}};
|
||||||
|
return {SUCCESS, header};
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto &addr : addresses) {
|
MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||||
std::string absolute_path = parent_dir + string(addr);
|
|
||||||
real_addresses.emplace_back(absolute_path);
|
|
||||||
}
|
|
||||||
uint32_t thread_num = std::thread::hardware_concurrency();
|
uint32_t thread_num = std::thread::hardware_concurrency();
|
||||||
if (thread_num == 0) thread_num = kThreadNumber;
|
if (thread_num == 0) thread_num = kThreadNumber;
|
||||||
uint32_t work_thread_num = 0;
|
uint32_t work_thread_num = 0;
|
||||||
uint32_t addr_count = real_addresses.size();
|
uint32_t shard_count = file_paths.size();
|
||||||
int group_num = ceil(addr_count * 1.0 / thread_num);
|
int group_num = ceil(shard_count * 1.0 / thread_num);
|
||||||
std::vector<std::thread> thread_set(thread_num);
|
std::vector<std::thread> thread_set(thread_num);
|
||||||
std::vector<json> headers(addr_count);
|
std::vector<json> headers(shard_count);
|
||||||
for (uint32_t x = 0; x < thread_num; ++x) {
|
for (uint32_t x = 0; x < thread_num; ++x) {
|
||||||
int start_num = x * group_num;
|
int start_num = x * group_num;
|
||||||
int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num;
|
int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
|
||||||
if (start_num >= end_num) {
|
if (start_num >= end_num) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_set[x] =
|
thread_set[x] =
|
||||||
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses);
|
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
|
||||||
work_thread_num++;
|
work_thread_num++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
|
||||||
thread_status = false;
|
thread_status = false;
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (SUCCESS != InitializeHeader(headers)) {
|
if (SUCCESS != InitializeHeader(headers, load_dataset)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
|
@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShardHeader::ParsePage(const json &pages) {
|
void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
|
||||||
|
// set shard_index when load_dataset is false
|
||||||
if (pages_.empty() && shard_count_ <= kMaxShardCount) {
|
if (pages_.empty() && shard_count_ <= kMaxShardCount) {
|
||||||
pages_.resize(shard_count_);
|
pages_.resize(shard_count_);
|
||||||
}
|
}
|
||||||
|
@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
|
||||||
|
|
||||||
std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
|
std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
|
||||||
end_row_id, row_group_ids, page_size);
|
end_row_id, row_group_ids, page_size);
|
||||||
pages_[shard_id].push_back(std::move(parsed_page));
|
if (load_dataset == true) {
|
||||||
|
pages_[shard_id].push_back(std::move(parsed_page));
|
||||||
|
} else {
|
||||||
|
pages_[shard_index].push_back(std::move(parsed_page));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
|
||||||
|
|
||||||
std::string line;
|
std::string line;
|
||||||
while (std::getline(page_in_handle, line)) {
|
while (std::getline(page_in_handle, line)) {
|
||||||
ParsePage(json::parse(line));
|
ParsePage(json::parse(line), -1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
page_in_handle.close();
|
page_in_handle.close();
|
||||||
|
|
|
@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
|
||||||
A source dataset that reads from shard files and database.
|
A source dataset that reads from shard files and database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_file (str): one of file names in dataset.
|
dataset_file (str, list[str]): One of file names or file list in dataset.
|
||||||
columns_list (list[str], optional): List of columns to be read (default=None).
|
columns_list (list[str], optional): List of columns to be read (default=None).
|
||||||
num_parallel_workers (int, optional): The number of readers (default=None).
|
num_parallel_workers (int, optional): The number of readers (default=None).
|
||||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||||
|
@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
|
||||||
shuffle=None, num_shards=None, shard_id=None,
|
shuffle=None, num_shards=None, shard_id=None,
|
||||||
block_reader=False, sampler=None):
|
block_reader=False, sampler=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
|
if isinstance(dataset_file, list):
|
||||||
|
self.load_dataset = False
|
||||||
|
else:
|
||||||
|
self.load_dataset = True
|
||||||
self.dataset_file = dataset_file
|
self.dataset_file = dataset_file
|
||||||
self.columns_list = columns_list
|
self.columns_list = columns_list
|
||||||
self.global_shuffle = shuffle
|
self.global_shuffle = shuffle
|
||||||
|
@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
args = super().get_args()
|
args = super().get_args()
|
||||||
args["dataset_file"] = self.dataset_file
|
args["dataset_file"] = self.dataset_file
|
||||||
|
args["load_dataset"] = self.load_dataset
|
||||||
args["columns_list"] = self.columns_list
|
args["columns_list"] = self.columns_list
|
||||||
args["global_shuffle"] = self.global_shuffle
|
args["global_shuffle"] = self.global_shuffle
|
||||||
args["partitions"] = self.partitions
|
args["partitions"] = self.partitions
|
||||||
|
@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
|
||||||
Return:
|
Return:
|
||||||
Number, number of batches.
|
Number, number of batches.
|
||||||
"""
|
"""
|
||||||
|
if self.load_dataset:
|
||||||
num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler)
|
dataset_file = [self.dataset_file]
|
||||||
|
else:
|
||||||
|
dataset_file = self.dataset_file
|
||||||
|
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
|
||||||
if self.partitions is not None and self.partitions[0] > 0:
|
if self.partitions is not None and self.partitions[0] > 0:
|
||||||
if num_rows % self.partitions[0] == 0:
|
if num_rows % self.partitions[0] == 0:
|
||||||
num_rows = num_rows // self.partitions[0]
|
num_rows = num_rows // self.partitions[0]
|
||||||
|
|
|
@ -529,8 +529,11 @@ def check_minddataset(method):
|
||||||
dataset_file = param_dict.get('dataset_file')
|
dataset_file = param_dict.get('dataset_file')
|
||||||
if dataset_file is None:
|
if dataset_file is None:
|
||||||
raise ValueError("dataset_file is not provided.")
|
raise ValueError("dataset_file is not provided.")
|
||||||
check_dataset_file(dataset_file)
|
if isinstance(dataset_file, list):
|
||||||
|
for f in dataset_file:
|
||||||
|
check_dataset_file(f)
|
||||||
|
else:
|
||||||
|
check_dataset_file(dataset_file)
|
||||||
check_param_type(nreq_param_int, param_dict, int)
|
check_param_type(nreq_param_int, param_dict, int)
|
||||||
|
|
||||||
check_param_type(nreq_param_list, param_dict, list)
|
check_param_type(nreq_param_list, param_dict, list)
|
||||||
|
|
|
@ -28,7 +28,7 @@ class FileReader:
|
||||||
Class to read MindRecord File series.
|
Class to read MindRecord File series.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): File name of MindRecord File.
|
file_name (str, list[str]): One of MindRecord File or file list.
|
||||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
||||||
It should not be smaller than 1 or larger than the number of CPU.
|
It should not be smaller than 1 or larger than the number of CPU.
|
||||||
columns (list[str], optional): List of fields which correspond data would be read (default=None).
|
columns (list[str], optional): List of fields which correspond data would be read (default=None).
|
||||||
|
@ -38,8 +38,11 @@ class FileReader:
|
||||||
ParamValueError: If file_name, num_consumer or columns is invalid.
|
ParamValueError: If file_name, num_consumer or columns is invalid.
|
||||||
"""
|
"""
|
||||||
def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
|
def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
|
||||||
check_filename(file_name)
|
if isinstance(file_name, list):
|
||||||
self._file_name = file_name
|
for f in file_name:
|
||||||
|
check_filename(f)
|
||||||
|
else:
|
||||||
|
check_filename(file_name)
|
||||||
|
|
||||||
if num_consumer is not None:
|
if num_consumer is not None:
|
||||||
if isinstance(num_consumer, int):
|
if isinstance(num_consumer, int):
|
||||||
|
|
|
@ -28,7 +28,7 @@ class MindPage:
|
||||||
Class to read MindRecord File series in pagination.
|
Class to read MindRecord File series in pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): File name of MindRecord File.
|
file_name (str): One of MindRecord File or file list.
|
||||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
||||||
It should not be smaller than 1 or larger than the number of CPU.
|
It should not be smaller than 1 or larger than the number of CPU.
|
||||||
|
|
||||||
|
@ -37,8 +37,11 @@ class MindPage:
|
||||||
MRMInitSegmentError: If failed to initialize ShardSegment.
|
MRMInitSegmentError: If failed to initialize ShardSegment.
|
||||||
"""
|
"""
|
||||||
def __init__(self, file_name, num_consumer=4):
|
def __init__(self, file_name, num_consumer=4):
|
||||||
check_filename(file_name)
|
if isinstance(file_name, list):
|
||||||
self._file_name = file_name
|
for f in file_name:
|
||||||
|
check_filename(f)
|
||||||
|
else:
|
||||||
|
check_filename(file_name)
|
||||||
|
|
||||||
if num_consumer is not None:
|
if num_consumer is not None:
|
||||||
if isinstance(num_consumer, int):
|
if isinstance(num_consumer, int):
|
||||||
|
|
|
@ -35,7 +35,7 @@ class ShardReader:
|
||||||
Open file and prepare to read MindRecord File.
|
Open file and prepare to read MindRecord File.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): File name of MindRecord File.
|
file_name (str, list[str]): File names of MindRecord File.
|
||||||
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
||||||
columns (list[str]): List of fields which correspond data would be read.
|
columns (list[str]): List of fields which correspond data would be read.
|
||||||
operator(int): Reserved parameter for operators. Default: None.
|
operator(int): Reserved parameter for operators. Default: None.
|
||||||
|
@ -48,7 +48,12 @@ class ShardReader:
|
||||||
"""
|
"""
|
||||||
columns = columns if columns else []
|
columns = columns if columns else []
|
||||||
operator = operator if operator else []
|
operator = operator if operator else []
|
||||||
ret = self._reader.open(file_name, num_consumer, columns, operator)
|
if isinstance(file_name, list):
|
||||||
|
load_dataset = False
|
||||||
|
else:
|
||||||
|
load_dataset = True
|
||||||
|
file_name = [file_name]
|
||||||
|
ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
|
||||||
if ret != ms.MSRStatus.SUCCESS:
|
if ret != ms.MSRStatus.SUCCESS:
|
||||||
logger.error("Failed to open {}.".format(file_name))
|
logger.error("Failed to open {}.".format(file_name))
|
||||||
raise MRMOpenError
|
raise MRMOpenError
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ShardSegment:
|
||||||
Initialize the ShardSegment.
|
Initialize the ShardSegment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): File name of MindRecord File.
|
file_name (str, list[str]): File names of MindRecord File.
|
||||||
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
||||||
columns (list[str]): List of fields which correspond data would be read.
|
columns (list[str]): List of fields which correspond data would be read.
|
||||||
operator(int): Reserved parameter for operators. Default: None.
|
operator(int): Reserved parameter for operators. Default: None.
|
||||||
|
@ -53,7 +53,12 @@ class ShardSegment:
|
||||||
"""
|
"""
|
||||||
self._columns = columns if columns else []
|
self._columns = columns if columns else []
|
||||||
operator = operator if operator else []
|
operator = operator if operator else []
|
||||||
ret = self._segment.open(file_name, num_consumer, self._columns, operator)
|
if isinstance(file_name, list):
|
||||||
|
load_dataset = False
|
||||||
|
else:
|
||||||
|
load_dataset = True
|
||||||
|
file_name = [file_name]
|
||||||
|
ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
|
||||||
if ret != SUCCESS:
|
if ret != SUCCESS:
|
||||||
logger.error("Failed to open {}.".format(file_name))
|
logger.error("Failed to open {}.".format(file_name))
|
||||||
raise MRMOpenError
|
raise MRMOpenError
|
||||||
|
|
|
@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list);
|
.SetColumnsToLoad(column_list);
|
||||||
|
@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list)
|
.SetColumnsToLoad(column_list)
|
||||||
|
@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list)
|
.SetColumnsToLoad(column_list)
|
||||||
|
@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list)
|
.SetColumnsToLoad(column_list)
|
||||||
|
@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list);
|
.SetColumnsToLoad(column_list);
|
||||||
|
@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetBlockReader()
|
.SetBlockReader()
|
||||||
|
@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
|
||||||
|
|
||||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||||
MindRecordOp::Builder builder;
|
MindRecordOp::Builder builder;
|
||||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||||
|
.SetLoadDataset(true)
|
||||||
.SetRowsPerBuffer(3)
|
.SetRowsPerBuffer(3)
|
||||||
.SetNumMindRecordWorkers(4)
|
.SetNumMindRecordWorkers(4)
|
||||||
.SetColumnsToLoad(column_list);
|
.SetColumnsToLoad(column_list);
|
||||||
|
|
|
@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
|
||||||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||||
ops.push_back(std::make_shared<ShardSample>(kSampleCount));
|
ops.push_back(std::make_shared<ShardSample>(kSampleCount));
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
||||||
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
|
||||||
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
|
||||||
ASSERT_TRUE(partitions.second == 2);
|
ASSERT_TRUE(partitions.second == 2);
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
|
||||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
|
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name},true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
|
||||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
|
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name},true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
|
||||||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
|
||||||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 16, column_list, ops);
|
dataset.Open({file_name}, true, 16, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
||||||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
|
||||||
ops.push_back(std::make_shared<ShardSample>(kSampleSize));
|
ops.push_back(std::make_shared<ShardSample>(kSampleSize));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
|
||||||
ops.push_back(std::make_shared<ShardSample>(35));
|
ops.push_back(std::make_shared<ShardSample>(35));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
|
||||||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
ShardReader compare_dataset;
|
ShardReader compare_dataset;
|
||||||
compare_dataset.Open(file_name, 4, column_list);
|
compare_dataset.Open({file_name},true, 4, column_list);
|
||||||
compare_dataset.Launch();
|
compare_dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
|
||||||
ops.push_back(std::make_shared<ShardShuffle>(21));
|
ops.push_back(std::make_shared<ShardShuffle>(21));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
||||||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
|
||||||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name},true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
|
||||||
ops.push_back(std::make_shared<ShardShuffle>(100));
|
ops.push_back(std::make_shared<ShardShuffle>(100));
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
|
|
@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
|
||||||
auto column_list = std::vector<std::string>{"file_name"};
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list);
|
dataset.Open({file_name}, true, 4, column_list);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
|
||||||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||||
ops.push_back(std::make_shared<ShardSample>(17));
|
ops.push_back(std::make_shared<ShardSample>(17));
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
|
||||||
ops.push_back(std::make_shared<ShardSample>(3));
|
ops.push_back(std::make_shared<ShardSample>(3));
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
const bool kBlockReader = true;
|
const bool kBlockReader = true;
|
||||||
dataset.Open(file_name, 4, column_list, ops, kBlockReader);
|
dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
|
||||||
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name);
|
dataset.Open({file_name}, true);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
auto column_list = std::vector<std::string>{"label"};
|
auto column_list = std::vector<std::string>{"label"};
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
auto column_list = std::vector<std::string>{"file_namex"};
|
auto column_list = std::vector<std::string>{"file_namex"};
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
|
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
|
||||||
MS_LOG(INFO) << FormatInfo("Test shard version");
|
MS_LOG(INFO) << FormatInfo("Test shard version");
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(file_name, 4);
|
MSRStatus ret = dataset.Open({file_name}, true, 4);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
|
||||||
auto column_list = std::vector<std::string>{"file_name"};
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, FAILED);
|
ASSERT_EQ(ret, FAILED);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
|
||||||
auto column_list = std::vector<std::string>{"file_name"};
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
dataset.Open(file_name, -481565535, column_list);
|
dataset.Open({file_name}, true, -481565535, column_list);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
|
|
@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
|
||||||
ShardSegment dataset;
|
ShardSegment dataset;
|
||||||
dataset.Open(file_name, 4);
|
dataset.Open({file_name}, true, 4);
|
||||||
|
|
||||||
auto x = dataset.GetCategoryFields();
|
auto x = dataset.GetCategoryFields();
|
||||||
for (const auto &fields : x.second) {
|
for (const auto &fields : x.second) {
|
||||||
|
@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
|
||||||
ShardSegment dataset;
|
ShardSegment dataset;
|
||||||
dataset.Open(file_name, 4);
|
dataset.Open({file_name}, true, 4);
|
||||||
|
|
||||||
auto x = dataset.GetCategoryFields();
|
auto x = dataset.GetCategoryFields();
|
||||||
for (const auto &fields : x.second) {
|
for (const auto &fields : x.second) {
|
||||||
|
@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
|
||||||
ShardSegment dataset;
|
ShardSegment dataset;
|
||||||
dataset.Open(file_name, 4);
|
dataset.Open({file_name}, true, 4);
|
||||||
|
|
||||||
auto x = dataset.GetCategoryFields();
|
auto x = dataset.GetCategoryFields();
|
||||||
for (const auto &fields : x.second) {
|
for (const auto &fields : x.second) {
|
||||||
|
@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
|
||||||
ShardSegment dataset;
|
ShardSegment dataset;
|
||||||
dataset.Open(file_name, 4);
|
dataset.Open({file_name}, true, 4);
|
||||||
|
|
||||||
auto x = dataset.GetCategoryFields();
|
auto x = dataset.GetCategoryFields();
|
||||||
for (const auto &fields : x.second) {
|
for (const auto &fields : x.second) {
|
||||||
|
@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
|
||||||
ShardSegment dataset;
|
ShardSegment dataset;
|
||||||
dataset.Open(file_name, 4);
|
dataset.Open({file_name}, true, 4);
|
||||||
|
|
||||||
auto x = dataset.GetCategoryFields();
|
auto x = dataset.GetCategoryFields();
|
||||||
for (const auto &fields : x.second) {
|
for (const auto &fields : x.second) {
|
||||||
|
|
|
@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
|
||||||
std::string filename = "./OneSample.shard01";
|
std::string filename = "./OneSample.shard01";
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(filename, 4);
|
MSRStatus ret = dataset.Open({filename}, true, 4);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
|
||||||
filename = "./imagenet.shard01";
|
filename = "./imagenet.shard01";
|
||||||
auto column_list = std::vector<std::string>{"label", "file_name", "data"};
|
auto column_list = std::vector<std::string>{"label", "file_name", "data"};
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
|
||||||
filename = "./imagenet.shard01";
|
filename = "./imagenet.shard01";
|
||||||
auto column_list = std::vector<std::string>{"label", "file_name"};
|
auto column_list = std::vector<std::string>{"label", "file_name"};
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
|
||||||
filename = "./imagenet.shard01";
|
filename = "./imagenet.shard01";
|
||||||
auto column_list = std::vector<std::string>{"label", "data"};
|
auto column_list = std::vector<std::string>{"label", "data"};
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
|
||||||
|
|
||||||
filename = "./TenSampleFortyShard.shard01";
|
filename = "./TenSampleFortyShard.shard01";
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
MSRStatus ret = dataset.Open(filename, 4);
|
MSRStatus ret = dataset.Open({filename}, true, 4);
|
||||||
ASSERT_EQ(ret, SUCCESS);
|
ASSERT_EQ(ret, SUCCESS);
|
||||||
dataset.Launch();
|
dataset.Launch();
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
FILES_NUM = 4
|
FILES_NUM = 4
|
||||||
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
||||||
|
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
|
||||||
|
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
|
||||||
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
||||||
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
|
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
|
||||||
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
|
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
|
||||||
|
@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
|
||||||
os.remove("{}".format(x))
|
os.remove("{}".format(x))
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
|
|
||||||
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
||||||
"""tutorial for cv minddataset."""
|
"""tutorial for cv minddataset."""
|
||||||
columns_list = ["data", "file_name", "label"]
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
|
||||||
assert num_iter == 20
|
assert num_iter == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||||
|
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 10
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
|
||||||
|
assert data_set.get_dataset_size() < 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||||
|
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter < 10
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
if os.path.exists(CV1_FILE_NAME):
|
||||||
|
os.remove(CV1_FILE_NAME)
|
||||||
|
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||||
|
if os.path.exists(CV2_FILE_NAME):
|
||||||
|
os.remove(CV2_FILE_NAME)
|
||||||
|
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
|
||||||
|
os.remove("{}.db".format(CV2_FILE_NAME))
|
||||||
|
writer = FileWriter(CV1_FILE_NAME, 1)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
cv_schema_json = {"id": {"type": "int32"},
|
||||||
|
"file_name": {"type": "string"},
|
||||||
|
"label": {"type": "int32"},
|
||||||
|
"data": {"type": "bytes"}}
|
||||||
|
writer.add_schema(cv_schema_json, "CV1_schema")
|
||||||
|
writer.add_index(["file_name", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
writer = FileWriter(CV2_FILE_NAME, 1)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
cv_schema_json = {"id": {"type": "int32"},
|
||||||
|
"file_name": {"type": "string"},
|
||||||
|
"label": {"type": "int32"},
|
||||||
|
"data": {"type": "bytes"}}
|
||||||
|
writer.add_schema(cv_schema_json, "CV2_schema")
|
||||||
|
writer.add_index(["file_name", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers)
|
||||||
|
assert data_set.get_dataset_size() == 30
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||||
|
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 30
|
||||||
|
if os.path.exists(CV1_FILE_NAME):
|
||||||
|
os.remove(CV1_FILE_NAME)
|
||||||
|
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||||
|
if os.path.exists(CV2_FILE_NAME):
|
||||||
|
os.remove(CV2_FILE_NAME)
|
||||||
|
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
|
||||||
|
os.remove("{}.db".format(CV2_FILE_NAME))
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
|
||||||
|
paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
|
||||||
|
for x in range(FILES_NUM)]
|
||||||
|
for x in paths:
|
||||||
|
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
|
||||||
|
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
|
||||||
|
writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
cv_schema_json = {"id": {"type": "int32"},
|
||||||
|
"file_name": {"type": "string"},
|
||||||
|
"label": {"type": "int32"},
|
||||||
|
"data": {"type": "bytes"}}
|
||||||
|
writer.add_schema(cv_schema_json, "CV1_schema")
|
||||||
|
writer.add_index(["file_name", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], columns_list, num_readers)
|
||||||
|
assert data_set.get_dataset_size() < 20
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||||
|
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter < 20
|
||||||
|
for x in paths:
|
||||||
|
os.remove("{}".format(x))
|
||||||
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
|
|
||||||
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
|
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
|
||||||
"""tutorial for cv minderdataset."""
|
"""tutorial for cv minderdataset."""
|
||||||
columns_list = ["data", "file_name", "label"]
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.dataset as ds
|
||||||
from mindspore.mindrecord import FileWriter
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
CV_FILE_NAME = "./imagenet.mindrecord"
|
CV_FILE_NAME = "./imagenet.mindrecord"
|
||||||
|
CV1_FILE_NAME = "./imagenet1.mindrecord"
|
||||||
|
|
||||||
|
|
||||||
def create_cv_mindrecord(files_num):
|
def create_cv_mindrecord(files_num):
|
||||||
|
@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
|
||||||
writer.commit()
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def create_diff_schema_cv_mindrecord(files_num):
|
||||||
|
"""tutorial for cv dataset writer."""
|
||||||
|
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
|
||||||
|
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||||
|
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||||
|
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||||
|
writer.add_schema(cv_schema_json, "img_schema")
|
||||||
|
writer.add_index(["file_name_1", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
def create_diff_page_size_cv_mindrecord(files_num):
|
||||||
|
"""tutorial for cv dataset writer."""
|
||||||
|
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
|
||||||
|
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||||
|
writer.set_page_size(1<< 26) #64MB
|
||||||
|
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||||
|
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||||
|
writer.add_schema(cv_schema_json, "img_schema")
|
||||||
|
writer.add_index(["file_name", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
def test_cv_lack_json():
|
def test_cv_lack_json():
|
||||||
"""tutorial for cv minderdataset."""
|
"""tutorial for cv minderdataset."""
|
||||||
create_cv_mindrecord(1)
|
create_cv_mindrecord(1)
|
||||||
|
@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
|
||||||
os.remove(CV_FILE_NAME)
|
os.remove(CV_FILE_NAME)
|
||||||
os.remove("{}.db".format(CV_FILE_NAME))
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_different_schema():
|
||||||
|
create_cv_mindrecord(1)
|
||||||
|
create_diff_schema_cv_mindrecord(1)
|
||||||
|
columns_list = ["data", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||||
|
num_readers)
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
num_iter += 1
|
||||||
|
os.remove(CV_FILE_NAME)
|
||||||
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
os.remove(CV1_FILE_NAME)
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||||
|
|
||||||
|
def test_cv_minddataset_reader_different_page_size():
|
||||||
|
create_cv_mindrecord(1)
|
||||||
|
create_diff_page_size_cv_mindrecord(1)
|
||||||
|
columns_list = ["data", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||||
|
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||||
|
num_readers)
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
num_iter += 1
|
||||||
|
os.remove(CV_FILE_NAME)
|
||||||
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
os.remove(CV1_FILE_NAME)
|
||||||
|
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||||
|
|
|
@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
|
||||||
assert count == 10
|
assert count == 10
|
||||||
reader.close()
|
reader.close()
|
||||||
|
|
||||||
|
def test_cv_file_reader_file_list():
|
||||||
|
"""tutorial for cv file partial reader."""
|
||||||
|
reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
|
||||||
|
count = 0
|
||||||
|
for index, x in enumerate(reader.get_next()):
|
||||||
|
assert len(x) == 3
|
||||||
|
count = count + 1
|
||||||
|
logger.info("#item{}: {}".format(index, x))
|
||||||
|
assert count == 10
|
||||||
|
|
||||||
def test_cv_file_reader_partial_tutorial():
|
def test_cv_file_reader_partial_tutorial():
|
||||||
"""tutorial for cv file partial reader."""
|
"""tutorial for cv file partial reader."""
|
||||||
reader = FileReader(CV_FILE_NAME + "0")
|
reader = FileReader(CV_FILE_NAME + "0")
|
||||||
|
|
Loading…
Reference in New Issue