forked from mindspore-Ecosystem/mindspore
mindrecord support read file list
This commit is contained in:
parent
a2d5ad5abe
commit
aa3f89e74f
|
@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
|||
}
|
||||
|
||||
std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
|
||||
(void)builder->SetDatasetFile(ToString(args["dataset_file"]));
|
||||
|
||||
bool load_dataset = ToBool(args["load_dataset"]);
|
||||
if (load_dataset == true) {
|
||||
(void)builder->SetDatasetFile({ToString(args["dataset_file"])});
|
||||
} else {
|
||||
(void)builder->SetDatasetFile(ToStringVector(args["dataset_file"]));
|
||||
}
|
||||
(void)builder->SetLoadDataset(load_dataset);
|
||||
std::vector<std::string> in_col_names;
|
||||
if (!args["columns_list"].is_none()) {
|
||||
in_col_names = ToStringVector(args["columns_list"]);
|
||||
|
|
|
@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
|
|||
});
|
||||
|
||||
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
||||
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||
auto create = sampler.attr("_create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
|
||||
return count;
|
||||
});
|
||||
.def_static("get_num_rows",
|
||||
[](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||
auto create = sampler.attr("_create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
||||
.def_static("get_num_rows_and_classes",
|
||||
|
|
|
@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
|
|||
using mindrecord::ShardReader;
|
||||
|
||||
// Builder constructor. Creates the builder object.
|
||||
MindRecordOp::Builder::Builder() : build_dataset_file_("") {
|
||||
MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
|
||||
// Some arguments to the MindRecordOp constructor have a default argument that is taken
|
||||
// from the client config.
|
||||
// The user may choose to change these values for the construction of the StorageOp by
|
||||
|
@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
|
|||
"Building a MindRecordOp that has not provided a file.");
|
||||
}
|
||||
|
||||
new_mind_record_op = std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_rows_per_buffer_,
|
||||
build_dataset_file_, build_op_connector_queue_size_,
|
||||
build_columns_to_load_, build_operators_, build_block_reader_);
|
||||
new_mind_record_op = std::make_shared<MindRecordOp>(
|
||||
build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_,
|
||||
build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_);
|
||||
|
||||
RETURN_IF_NOT_OK(new_mind_record_op->Init());
|
||||
|
||||
|
@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
|
|||
Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
|
||||
|
||||
// Constructor of the MindRecordOp.
|
||||
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
|
||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer,
|
||||
std::vector<std::string> dataset_file, bool load_dataset, int32_t op_connector_queue_size,
|
||||
const std::vector<std::string> &columns_to_load,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
|
||||
: ParallelOp(num_mind_record_workers, op_connector_queue_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
dataset_file_(dataset_file),
|
||||
load_dataset_(load_dataset),
|
||||
columns_to_load_(columns_to_load),
|
||||
operators_(operators),
|
||||
num_mind_record_workers_(num_mind_record_workers),
|
||||
|
@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
|
|||
// Private helper method to encapsulate some common construction/reset tasks
|
||||
Status MindRecordOp::Init() {
|
||||
shard_reader_ = std::make_unique<ShardReader>();
|
||||
auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_);
|
||||
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
|
||||
block_reader_);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED,
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS,
|
||||
"MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
|
||||
|
||||
data_schema_ = std::make_unique<DataSchema>();
|
||||
|
@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_
|
||||
<< "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_
|
||||
out << "\n Dataset file : ";
|
||||
for (auto &file : dataset_file_) {
|
||||
out << file << " ";
|
||||
}
|
||||
out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_
|
||||
<< "\nNumber of buffers : " << buffers_needed_
|
||||
<< "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
||||
int64_t *count) {
|
||||
Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
|
||||
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
|
||||
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
||||
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
|
||||
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count);
|
||||
if (rc == MSRStatus::FAILED) {
|
||||
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
|
||||
}
|
||||
|
|
|
@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetDatasetFile(const std::string &file) {
|
||||
build_dataset_file_ = file;
|
||||
Builder &SetDatasetFile(const std::vector<std::string> &files) {
|
||||
build_dataset_file_ = files;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetLoadDataset(bool load_dataset) {
|
||||
build_load_dataset_ = load_dataset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status SanityCheck() const;
|
||||
|
||||
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
|
||||
|
@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
|
|||
int32_t builder_num_workers_;
|
||||
int32_t build_rows_per_buffer_;
|
||||
int32_t build_op_connector_queue_size_;
|
||||
std::string build_dataset_file_;
|
||||
std::vector<std::string> build_dataset_file_;
|
||||
bool build_load_dataset_;
|
||||
std::vector<std::string> build_columns_to_load_;
|
||||
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
|
||||
bool build_block_reader_;
|
||||
|
@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
|
|||
// @note The builder class should be used to call it
|
||||
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
|
||||
// @param rows_per_buffer - The requested number of rows per buffer
|
||||
// @param dataset_file - A shard file
|
||||
// @param dataset_file - dataset files
|
||||
// @param op_connector_queue_size - The output connector queue size
|
||||
// @param columns_to_load - The list of columns to use (column name)
|
||||
// @param operators - ShardOperators for Shuffle, Category, Sample
|
||||
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
|
||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector<std::string> dataset_file,
|
||||
bool load_dataset, int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
|
||||
|
||||
// Destructor
|
||||
|
@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
|
|||
// Getter method
|
||||
int32_t num_rows() const { return num_rows_; }
|
||||
|
||||
// Getter method
|
||||
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
||||
int64_t *count);
|
||||
static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
|
||||
const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||
|
||||
// Getter method
|
||||
int32_t rows_per_buffer() const { return rows_per_buffer_; }
|
||||
|
||||
// Getter method
|
||||
std::string dataset_file() const { return dataset_file_; }
|
||||
std::vector<std::string> dataset_file() const { return dataset_file_; }
|
||||
|
||||
// Getter method
|
||||
std::vector<std::string> columns_to_load() const { return columns_to_load_; }
|
||||
|
||||
bool block_reader() const { return block_reader_; }
|
||||
|
||||
bool load_dataset() const { return load_dataset_; }
|
||||
|
||||
Status Init();
|
||||
|
||||
Status SetColumnsBlob();
|
||||
|
@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
|
|||
Status FetchBlockBuffer(const int32_t &buffer_id);
|
||||
|
||||
int32_t rows_per_buffer_; // The number of requested rows per buffer.
|
||||
std::string dataset_file_; // A dataset file
|
||||
std::vector<std::string> dataset_file_; // dataset files
|
||||
bool load_dataset_; // load dataset from single file or not
|
||||
std::vector<std::string> columns_to_load_; // Columns to load from dataset
|
||||
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
|
||||
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
|
||||
|
|
|
@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
|
|||
case IO_FAILED:
|
||||
return "io operate failed";
|
||||
break;
|
||||
case MATCH_HEADER_FAILED:
|
||||
return "match header failed";
|
||||
break;
|
||||
default:
|
||||
return "invalid error no";
|
||||
}
|
||||
|
|
|
@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
|
|||
void BindShardReader(const py::module *m) {
|
||||
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
|
||||
.def(py::init<>())
|
||||
.def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector<std::string> &,
|
||||
.def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
|
||||
const std::vector<std::string> &,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
||||
ShardReader::OpenPy)
|
||||
.def("launch", &ShardReader::Launch)
|
||||
|
@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
|
|||
void BindShardSegment(py::module *m) {
|
||||
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
|
||||
.def(py::init<>())
|
||||
.def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector<std::string> &,
|
||||
.def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
|
||||
const std::vector<std::string> &,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
||||
ShardSegment::OpenPy)
|
||||
.def("get_category_fields",
|
||||
|
|
|
@ -72,7 +72,8 @@ enum MSRStatus {
|
|||
ILLEGAL_PARAMETERS,
|
||||
GET_PAGE_BY_GROUP_ID_FAILED,
|
||||
GET_SYSTEM_STATE_FAILED,
|
||||
IO_FAILED
|
||||
IO_FAILED,
|
||||
MATCH_HEADER_FAILED
|
||||
};
|
||||
|
||||
// convert error no to string message
|
||||
|
|
|
@ -35,10 +35,11 @@ class ShardHeader {
|
|||
public:
|
||||
ShardHeader();
|
||||
|
||||
MSRStatus Build(const std::string &file_path);
|
||||
|
||||
~ShardHeader() = default;
|
||||
|
||||
MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
|
||||
|
||||
static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
|
||||
/// \brief add the schema and save it
|
||||
/// \param[in] schema the schema needs to be added
|
||||
/// \return the last schema's id
|
||||
|
@ -126,7 +127,7 @@ class ShardHeader {
|
|||
MSRStatus FileToPages(const std::string dump_file_name);
|
||||
|
||||
private:
|
||||
MSRStatus InitializeHeader(const std::vector<json> &headers);
|
||||
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
|
||||
|
||||
/// \brief get the headers from all the shard data
|
||||
/// \param[in] the shard data real path
|
||||
|
@ -137,9 +138,9 @@ class ShardHeader {
|
|||
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
|
||||
|
||||
/// \brief check the binary file status
|
||||
MSRStatus CheckFileStatus(const std::string &path);
|
||||
static MSRStatus CheckFileStatus(const std::string &path);
|
||||
|
||||
std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
|
||||
static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
|
||||
|
||||
void ParseHeader(const json &header);
|
||||
|
||||
|
@ -149,7 +150,7 @@ class ShardHeader {
|
|||
|
||||
MSRStatus CheckIndexField(const std::string &field, const json &schema);
|
||||
|
||||
void ParsePage(const json &page);
|
||||
void ParsePage(const json &page, int shard_index, bool load_dataset);
|
||||
|
||||
MSRStatus ParseStatistics(const json &statistics);
|
||||
|
||||
|
|
|
@ -68,23 +68,25 @@ class ShardReader {
|
|||
virtual ~ShardReader();
|
||||
|
||||
/// \brief open files and initialize reader, c++ API
|
||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
||||
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||
/// \param[in] load_dataset load dataset from single file or not
|
||||
/// \param[in] n_consumer number of threads when reading
|
||||
/// \param[in] selected_columns column list to be populated
|
||||
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
||||
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus Open(const std::string &file_path, int n_consumer = 4,
|
||||
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
|
||||
const std::vector<std::string> &selected_columns = {},
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
|
||||
|
||||
/// \brief open files and initialize reader, python API
|
||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
||||
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||
/// \param[in] load_dataset load dataset from single file or not
|
||||
/// \param[in] n_consumer number of threads when reading
|
||||
/// \param[in] selected_columns column list to be populated
|
||||
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4,
|
||||
MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
|
||||
const std::vector<std::string> &selected_columns = {},
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
|
||||
|
||||
|
@ -114,11 +116,13 @@ class ShardReader {
|
|||
int GetShardCount() const;
|
||||
|
||||
/// \brief get the number of rows in database
|
||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
||||
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||
/// \param[in] load_dataset load dataset from single file or not
|
||||
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
|
||||
/// \param[out] count # of rows
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||
MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||
const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||
|
||||
/// \brief shuffle task with incremental seed
|
||||
/// \return void
|
||||
|
@ -220,7 +224,7 @@ class ShardReader {
|
|||
std::vector<std::vector<json>> &column_values);
|
||||
|
||||
/// \brief initialize reader
|
||||
MSRStatus Init(const std::string &file_path);
|
||||
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
|
||||
|
||||
/// \brief validate column list
|
||||
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
|
||||
|
@ -292,8 +296,9 @@ class ShardReader {
|
|||
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
||||
|
||||
/// \brief get number of classes
|
||||
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
|
||||
int64_t GetNumClasses(const std::string &category_field);
|
||||
|
||||
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
|
||||
/// \brief get exactly blob fields data by indices
|
||||
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
|
||||
std::vector<uint32_t> &ordered_selected_columns_index);
|
||||
|
|
|
@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
|
|||
write_success_(true) {}
|
||||
|
||||
MSRStatus ShardIndexGenerator::Build() {
|
||||
auto ret = ShardHeader::BuildSingleHeader(file_path_);
|
||||
if (ret.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
auto json_header = ret.second;
|
||||
|
||||
auto ret2 = GetParentDir(file_path_);
|
||||
if (SUCCESS != ret2.first) {
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<std::string> real_addresses;
|
||||
for (const auto &path : json_header["shard_addresses"]) {
|
||||
std::string abs_path = ret2.second + string(path);
|
||||
real_addresses.emplace_back(abs_path);
|
||||
}
|
||||
ShardHeader header = ShardHeader();
|
||||
if (header.Build(file_path_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Build shard schema failed.";
|
||||
if (header.BuildDataset(real_addresses) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
shard_header_ = header;
|
||||
|
|
|
@ -47,20 +47,55 @@ ShardReader::ShardReader() {
|
|||
block_reader_ = false;
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::Init(const std::string &file_path) {
|
||||
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
||||
if (!IsLegalFile(file_path)) {
|
||||
return FAILED;
|
||||
return {FAILED, {}};
|
||||
}
|
||||
ShardHeader sh = ShardHeader();
|
||||
if (sh.Build(file_path) == FAILED) {
|
||||
return FAILED;
|
||||
auto ret = ShardHeader::BuildSingleHeader(file_path);
|
||||
if (ret.first != SUCCESS) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||
header_size_ = shard_header_->GetHeaderSize();
|
||||
page_size_ = shard_header_->GetPageSize();
|
||||
file_paths_ = shard_header_->GetShardAddresses();
|
||||
auto header = ret.second;
|
||||
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
||||
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
||||
return {SUCCESS, header["shard_addresses"]};
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||
std::string file_path = file_paths[0];
|
||||
json first_meta_data = json();
|
||||
auto ret = GetMeta(file_path, first_meta_data);
|
||||
if (ret.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (file_paths.size() == 1 && load_dataset == true) {
|
||||
auto ret2 = GetParentDir(file_path);
|
||||
if (SUCCESS != ret2.first) {
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<std::string> real_addresses;
|
||||
for (const auto &path : ret.second) {
|
||||
std::string abs_path = ret2.second + string(path);
|
||||
real_addresses.emplace_back(abs_path);
|
||||
}
|
||||
file_paths_ = real_addresses;
|
||||
} else if (file_paths.size() >= 1 && load_dataset == false) {
|
||||
file_paths_ = file_paths;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error in parameter file_path or load_dataset.";
|
||||
return FAILED;
|
||||
}
|
||||
for (const auto &file : file_paths_) {
|
||||
json meta_data = json();
|
||||
auto ret1 = GetMeta(file, meta_data);
|
||||
if (ret1.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (meta_data != first_meta_data) {
|
||||
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
|
||||
return FAILED;
|
||||
}
|
||||
sqlite3 *db = nullptr;
|
||||
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
|
||||
int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||
|
@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
|
|||
}
|
||||
database_paths_.push_back(db);
|
||||
}
|
||||
|
||||
ShardHeader sh = ShardHeader();
|
||||
if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||
header_size_ = shard_header_->GetHeaderSize();
|
||||
page_size_ = shard_header_->GetPageSize();
|
||||
num_rows_ = 0;
|
||||
auto row_group_summary = ReadRowGroupSummary();
|
||||
for (const auto &rg : row_group_summary) {
|
||||
|
@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
fs->close();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
json label_json = json::from_msgpack(label_raw);
|
||||
json tmp;
|
||||
if (!columns.empty()) {
|
||||
|
@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) {
|
||||
ShardHeader sh = ShardHeader();
|
||||
if (sh.Build(file_path) == FAILED) {
|
||||
return -1;
|
||||
}
|
||||
auto header = std::make_shared<ShardHeader>(sh);
|
||||
auto file_paths = header->GetShardAddresses();
|
||||
auto shard_count = file_paths.size();
|
||||
auto index_fields = header->GetFields();
|
||||
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
||||
auto shard_count = file_paths_.size();
|
||||
auto index_fields = shard_header_->GetFields();
|
||||
|
||||
std::map<std::string, int64_t> map_schema_id_fields;
|
||||
for (auto &field : index_fields) {
|
||||
|
@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
|||
std::set<std::string> categories;
|
||||
for (int x = 0; x < shard_count; x++) {
|
||||
sqlite3 *db = nullptr;
|
||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||
if (SQLITE_OK != rc) {
|
||||
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
|
||||
return -1;
|
||||
|
@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
|||
return categories.size();
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op,
|
||||
int64_t *count) {
|
||||
if (Init(file_path) == FAILED) {
|
||||
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
|
||||
if (SUCCESS != Init(file_paths, load_dataset)) {
|
||||
return FAILED;
|
||||
}
|
||||
int64_t num_samples = num_rows_;
|
||||
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
std::string category_field = category_op->GetCategoryField();
|
||||
auto num_classes = GetNumClasses(file_path, category_field);
|
||||
auto num_classes = GetNumClasses(category_field);
|
||||
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
num_samples = op->GetNumSamples(num_rows_, 0);
|
||||
|
@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
||||
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
|
||||
const std::vector<std::string> &selected_columns,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
|
||||
// Open file and set header by ShardReader
|
||||
if (Init(file_path) == FAILED) {
|
||||
return FAILED;
|
||||
auto ret = Init(file_paths, load_dataset);
|
||||
if (SUCCESS != ret) {
|
||||
return ret;
|
||||
}
|
||||
auto thread_limit = GetMaxThreadNum();
|
||||
if (n_consumer > thread_limit) {
|
||||
|
@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer,
|
||||
MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
|
||||
const std::vector<std::string> &selected_columns,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||
// Open file and set header by ShardReader
|
||||
if (Init(file_path) == FAILED) {
|
||||
if (SUCCESS != Init(file_paths, load_dataset)) {
|
||||
return FAILED;
|
||||
}
|
||||
// should remove blob field from selected_columns when call from python
|
||||
|
|
|
@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
|||
if (!IsLegalFile(path)) {
|
||||
return FAILED;
|
||||
}
|
||||
ShardHeader sh = ShardHeader();
|
||||
if (sh.Build(path) == FAILED) {
|
||||
auto ret1 = ShardHeader::BuildSingleHeader(path);
|
||||
if (ret1.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||
auto paths = shard_header_->GetShardAddresses();
|
||||
auto json_header = ret1.second;
|
||||
auto ret2 = GetParentDir(path);
|
||||
if (SUCCESS != ret2.first) {
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<std::string> real_addresses;
|
||||
for (const auto &path : json_header["shard_addresses"]) {
|
||||
std::string abs_path = ret2.second + string(path);
|
||||
real_addresses.emplace_back(abs_path);
|
||||
}
|
||||
ShardHeader header = ShardHeader();
|
||||
if (header.BuildDataset(real_addresses) == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(header);
|
||||
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
|
||||
if (ret == FAILED) {
|
||||
return FAILED;
|
||||
|
@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
|||
if (ret == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
ret = Open(paths, true);
|
||||
ret = Open(json_header["shard_addresses"], true);
|
||||
if (ret == FAILED) {
|
||||
MS_LOG(ERROR) << "Open file failed";
|
||||
return FAILED;
|
||||
|
|
|
@ -35,8 +35,9 @@ namespace mindrecord {
|
|||
std::atomic<bool> thread_status(false);
|
||||
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
|
||||
|
||||
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
|
||||
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
|
||||
shard_count_ = headers.size();
|
||||
int shard_index = 0;
|
||||
bool first = true;
|
||||
for (const auto &header : headers) {
|
||||
if (first) {
|
||||
|
@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
|
|||
header_size_ = header["header_size"].get<uint64_t>();
|
||||
page_size_ = header["page_size"].get<uint64_t>();
|
||||
}
|
||||
ParsePage(header["page"]);
|
||||
ParsePage(header["page"], shard_index, load_dataset);
|
||||
shard_index++;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
|
|||
return {SUCCESS, json_header};
|
||||
}
|
||||
|
||||
MSRStatus ShardHeader::Build(const std::string &file_path) {
|
||||
std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
|
||||
auto ret = ValidateHeader(file_path);
|
||||
if (SUCCESS != ret.first) {
|
||||
return FAILED;
|
||||
return {FAILED, json()};
|
||||
}
|
||||
json main_header = ret.second;
|
||||
json addresses = main_header["shard_addresses"];
|
||||
vector<string> real_addresses;
|
||||
auto ret1 = GetParentDir(file_path);
|
||||
if (SUCCESS != ret1.first) {
|
||||
return FAILED;
|
||||
}
|
||||
std::string parent_dir = ret1.second;
|
||||
json raw_header = ret.second;
|
||||
json header = {{"shard_addresses", raw_header["shard_addresses"]},
|
||||
{"header_size", raw_header["header_size"]},
|
||||
{"page_size", raw_header["page_size"]},
|
||||
{"index_fields", raw_header["index_fields"]},
|
||||
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
|
||||
{"schema", raw_header["schema"][0]["schema"]},
|
||||
{"version", raw_header["version"]}};
|
||||
return {SUCCESS, header};
|
||||
}
|
||||
|
||||
for (const auto &addr : addresses) {
|
||||
std::string absolute_path = parent_dir + string(addr);
|
||||
real_addresses.emplace_back(absolute_path);
|
||||
}
|
||||
MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||
uint32_t thread_num = std::thread::hardware_concurrency();
|
||||
if (thread_num == 0) thread_num = kThreadNumber;
|
||||
uint32_t work_thread_num = 0;
|
||||
uint32_t addr_count = real_addresses.size();
|
||||
int group_num = ceil(addr_count * 1.0 / thread_num);
|
||||
uint32_t shard_count = file_paths.size();
|
||||
int group_num = ceil(shard_count * 1.0 / thread_num);
|
||||
std::vector<std::thread> thread_set(thread_num);
|
||||
std::vector<json> headers(addr_count);
|
||||
std::vector<json> headers(shard_count);
|
||||
for (uint32_t x = 0; x < thread_num; ++x) {
|
||||
int start_num = x * group_num;
|
||||
int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num;
|
||||
int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
|
||||
if (start_num >= end_num) {
|
||||
continue;
|
||||
}
|
||||
|
||||
thread_set[x] =
|
||||
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses);
|
||||
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
|
||||
work_thread_num++;
|
||||
}
|
||||
|
||||
|
@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
|
|||
thread_status = false;
|
||||
return FAILED;
|
||||
}
|
||||
if (SUCCESS != InitializeHeader(headers)) {
|
||||
if (SUCCESS != InitializeHeader(headers, load_dataset)) {
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
|
@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ShardHeader::ParsePage(const json &pages) {
|
||||
void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
|
||||
// set shard_index when load_dataset is false
|
||||
if (pages_.empty() && shard_count_ <= kMaxShardCount) {
|
||||
pages_.resize(shard_count_);
|
||||
}
|
||||
|
@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
|
|||
|
||||
std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
|
||||
end_row_id, row_group_ids, page_size);
|
||||
pages_[shard_id].push_back(std::move(parsed_page));
|
||||
if (load_dataset == true) {
|
||||
pages_[shard_id].push_back(std::move(parsed_page));
|
||||
} else {
|
||||
pages_[shard_index].push_back(std::move(parsed_page));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
|
|||
|
||||
std::string line;
|
||||
while (std::getline(page_in_handle, line)) {
|
||||
ParsePage(json::parse(line));
|
||||
ParsePage(json::parse(line), -1, true);
|
||||
}
|
||||
|
||||
page_in_handle.close();
|
||||
|
|
|
@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
|
|||
A source dataset that reads from shard files and database.
|
||||
|
||||
Args:
|
||||
dataset_file (str): one of file names in dataset.
|
||||
dataset_file (str, list[str]): One of file names or file list in dataset.
|
||||
columns_list (list[str], optional): List of columns to be read (default=None).
|
||||
num_parallel_workers (int, optional): The number of readers (default=None).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
|
@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
|
|||
shuffle=None, num_shards=None, shard_id=None,
|
||||
block_reader=False, sampler=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
if isinstance(dataset_file, list):
|
||||
self.load_dataset = False
|
||||
else:
|
||||
self.load_dataset = True
|
||||
self.dataset_file = dataset_file
|
||||
self.columns_list = columns_list
|
||||
self.global_shuffle = shuffle
|
||||
|
@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
|
|||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_file"] = self.dataset_file
|
||||
args["load_dataset"] = self.load_dataset
|
||||
args["columns_list"] = self.columns_list
|
||||
args["global_shuffle"] = self.global_shuffle
|
||||
args["partitions"] = self.partitions
|
||||
|
@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
|
||||
num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler)
|
||||
if self.load_dataset:
|
||||
dataset_file = [self.dataset_file]
|
||||
else:
|
||||
dataset_file = self.dataset_file
|
||||
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
|
||||
if self.partitions is not None and self.partitions[0] > 0:
|
||||
if num_rows % self.partitions[0] == 0:
|
||||
num_rows = num_rows // self.partitions[0]
|
||||
|
|
|
@ -529,8 +529,11 @@ def check_minddataset(method):
|
|||
dataset_file = param_dict.get('dataset_file')
|
||||
if dataset_file is None:
|
||||
raise ValueError("dataset_file is not provided.")
|
||||
check_dataset_file(dataset_file)
|
||||
|
||||
if isinstance(dataset_file, list):
|
||||
for f in dataset_file:
|
||||
check_dataset_file(f)
|
||||
else:
|
||||
check_dataset_file(dataset_file)
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
|
||||
check_param_type(nreq_param_list, param_dict, list)
|
||||
|
|
|
@ -28,7 +28,7 @@ class FileReader:
|
|||
Class to read MindRecord File series.
|
||||
|
||||
Args:
|
||||
file_name (str): File name of MindRecord File.
|
||||
file_name (str, list[str]): One of MindRecord File or file list.
|
||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
||||
It should not be smaller than 1 or larger than the number of CPU.
|
||||
columns (list[str], optional): List of fields which correspond data would be read (default=None).
|
||||
|
@ -38,8 +38,11 @@ class FileReader:
|
|||
ParamValueError: If file_name, num_consumer or columns is invalid.
|
||||
"""
|
||||
def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
|
||||
check_filename(file_name)
|
||||
self._file_name = file_name
|
||||
if isinstance(file_name, list):
|
||||
for f in file_name:
|
||||
check_filename(f)
|
||||
else:
|
||||
check_filename(file_name)
|
||||
|
||||
if num_consumer is not None:
|
||||
if isinstance(num_consumer, int):
|
||||
|
|
|
@ -28,7 +28,7 @@ class MindPage:
|
|||
Class to read MindRecord File series in pagination.
|
||||
|
||||
Args:
|
||||
file_name (str): File name of MindRecord File.
|
||||
file_name (str): One of MindRecord File or file list.
|
||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
||||
It should not be smaller than 1 or larger than the number of CPU.
|
||||
|
||||
|
@ -37,8 +37,11 @@ class MindPage:
|
|||
MRMInitSegmentError: If failed to initialize ShardSegment.
|
||||
"""
|
||||
def __init__(self, file_name, num_consumer=4):
|
||||
check_filename(file_name)
|
||||
self._file_name = file_name
|
||||
if isinstance(file_name, list):
|
||||
for f in file_name:
|
||||
check_filename(f)
|
||||
else:
|
||||
check_filename(file_name)
|
||||
|
||||
if num_consumer is not None:
|
||||
if isinstance(num_consumer, int):
|
||||
|
|
|
@ -35,7 +35,7 @@ class ShardReader:
|
|||
Open file and prepare to read MindRecord File.
|
||||
|
||||
Args:
|
||||
file_name (str): File name of MindRecord File.
|
||||
file_name (str, list[str]): File names of MindRecord File.
|
||||
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
||||
columns (list[str]): List of fields which correspond data would be read.
|
||||
operator(int): Reserved parameter for operators. Default: None.
|
||||
|
@ -48,7 +48,12 @@ class ShardReader:
|
|||
"""
|
||||
columns = columns if columns else []
|
||||
operator = operator if operator else []
|
||||
ret = self._reader.open(file_name, num_consumer, columns, operator)
|
||||
if isinstance(file_name, list):
|
||||
load_dataset = False
|
||||
else:
|
||||
load_dataset = True
|
||||
file_name = [file_name]
|
||||
ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
|
||||
if ret != ms.MSRStatus.SUCCESS:
|
||||
logger.error("Failed to open {}.".format(file_name))
|
||||
raise MRMOpenError
|
||||
|
|
|
@ -40,7 +40,7 @@ class ShardSegment:
|
|||
Initialize the ShardSegment.
|
||||
|
||||
Args:
|
||||
file_name (str): File name of MindRecord File.
|
||||
file_name (str, list[str]): File names of MindRecord File.
|
||||
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
|
||||
columns (list[str]): List of fields which correspond data would be read.
|
||||
operator(int): Reserved parameter for operators. Default: None.
|
||||
|
@ -53,7 +53,12 @@ class ShardSegment:
|
|||
"""
|
||||
self._columns = columns if columns else []
|
||||
operator = operator if operator else []
|
||||
ret = self._segment.open(file_name, num_consumer, self._columns, operator)
|
||||
if isinstance(file_name, list):
|
||||
load_dataset = False
|
||||
else:
|
||||
load_dataset = True
|
||||
file_name = [file_name]
|
||||
ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
|
||||
if ret != SUCCESS:
|
||||
logger.error("Failed to open {}.".format(file_name))
|
||||
raise MRMOpenError
|
||||
|
|
|
@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list);
|
||||
|
@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list)
|
||||
|
@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list)
|
||||
|
@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list)
|
||||
|
@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list);
|
||||
|
@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetBlockReader()
|
||||
|
@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
|
|||
|
||||
std::shared_ptr<MindRecordOp> my_mindrecord_op;
|
||||
MindRecordOp::Builder builder;
|
||||
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
|
||||
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
|
||||
.SetLoadDataset(true)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetNumMindRecordWorkers(4)
|
||||
.SetColumnsToLoad(column_list);
|
||||
|
|
|
@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
|
|||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||
ops.push_back(std::make_shared<ShardSample>(kSampleCount));
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
|||
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
|
|||
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
|
|||
ASSERT_TRUE(partitions.second == 2);
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
|
|||
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name},true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
|
|||
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name},true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
|
|||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
|
|||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 16, column_list, ops);
|
||||
dataset.Open({file_name}, true, 16, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
|||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
|
|||
ops.push_back(std::make_shared<ShardSample>(kSampleSize));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
|
|||
ops.push_back(std::make_shared<ShardSample>(35));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
|
|||
ops.push_back(std::make_shared<ShardShuffle>(1));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
ShardReader compare_dataset;
|
||||
compare_dataset.Open(file_name, 4, column_list);
|
||||
compare_dataset.Open({file_name},true, 4, column_list);
|
||||
compare_dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
|
|||
ops.push_back(std::make_shared<ShardShuffle>(21));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
|||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
|
|||
ops.push_back(std::make_shared<ShardCategory>(categories));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name},true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
|
|||
ops.push_back(std::make_shared<ShardShuffle>(100));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
int i = 0;
|
||||
|
|
|
@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
|
|||
auto column_list = std::vector<std::string>{"file_name"};
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list);
|
||||
dataset.Open({file_name}, true, 4, column_list);
|
||||
dataset.Launch();
|
||||
|
||||
while (true) {
|
||||
|
@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
|
|||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||
ops.push_back(std::make_shared<ShardSample>(17));
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, 4, column_list, ops);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops);
|
||||
dataset.Launch();
|
||||
|
||||
while (true) {
|
||||
|
@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
|
|||
ops.push_back(std::make_shared<ShardSample>(3));
|
||||
ShardReader dataset;
|
||||
const bool kBlockReader = true;
|
||||
dataset.Open(file_name, 4, column_list, ops, kBlockReader);
|
||||
dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader);
|
||||
dataset.Launch();
|
||||
|
||||
while (true) {
|
||||
|
@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
|
|||
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
||||
std::string file_name = "./imagenet.shard01";
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name);
|
||||
dataset.Open({file_name}, true);
|
||||
dataset.Launch();
|
||||
|
||||
while (true) {
|
||||
|
@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
auto column_list = std::vector<std::string>{"label"};
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
auto column_list = std::vector<std::string>{"file_namex"};
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
|
|||
MS_LOG(INFO) << FormatInfo("Test shard version");
|
||||
std::string file_name = "./imagenet.shard01";
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(file_name, 4);
|
||||
MSRStatus ret = dataset.Open({file_name}, true, 4);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
|
|||
auto column_list = std::vector<std::string>{"file_name"};
|
||||
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(file_name, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, FAILED);
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
|
|||
auto column_list = std::vector<std::string>{"file_name"};
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open(file_name, -481565535, column_list);
|
||||
dataset.Open({file_name}, true, -481565535, column_list);
|
||||
dataset.Launch();
|
||||
|
||||
while (true) {
|
||||
|
|
|
@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
|
||||
ShardSegment dataset;
|
||||
dataset.Open(file_name, 4);
|
||||
dataset.Open({file_name}, true, 4);
|
||||
|
||||
auto x = dataset.GetCategoryFields();
|
||||
for (const auto &fields : x.second) {
|
||||
|
@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
|
||||
ShardSegment dataset;
|
||||
dataset.Open(file_name, 4);
|
||||
dataset.Open({file_name}, true, 4);
|
||||
|
||||
auto x = dataset.GetCategoryFields();
|
||||
for (const auto &fields : x.second) {
|
||||
|
@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
|
||||
ShardSegment dataset;
|
||||
dataset.Open(file_name, 4);
|
||||
dataset.Open({file_name}, true, 4);
|
||||
|
||||
auto x = dataset.GetCategoryFields();
|
||||
for (const auto &fields : x.second) {
|
||||
|
@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
|
||||
ShardSegment dataset;
|
||||
dataset.Open(file_name, 4);
|
||||
dataset.Open({file_name}, true, 4);
|
||||
|
||||
auto x = dataset.GetCategoryFields();
|
||||
for (const auto &fields : x.second) {
|
||||
|
@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
|
|||
std::string file_name = "./imagenet.shard01";
|
||||
|
||||
ShardSegment dataset;
|
||||
dataset.Open(file_name, 4);
|
||||
dataset.Open({file_name}, true, 4);
|
||||
|
||||
auto x = dataset.GetCategoryFields();
|
||||
for (const auto &fields : x.second) {
|
||||
|
|
|
@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
|
|||
std::string filename = "./OneSample.shard01";
|
||||
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(filename, 4);
|
||||
MSRStatus ret = dataset.Open({filename}, true, 4);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
|
|||
filename = "./imagenet.shard01";
|
||||
auto column_list = std::vector<std::string>{"label", "file_name", "data"};
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
|
|||
filename = "./imagenet.shard01";
|
||||
auto column_list = std::vector<std::string>{"label", "file_name"};
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
|
|||
filename = "./imagenet.shard01";
|
||||
auto column_list = std::vector<std::string>{"label", "data"};
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(filename, 4, column_list);
|
||||
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
|
|||
|
||||
filename = "./TenSampleFortyShard.shard01";
|
||||
ShardReader dataset;
|
||||
MSRStatus ret = dataset.Open(filename, 4);
|
||||
MSRStatus ret = dataset.Open({filename}, true, 4);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
dataset.Launch();
|
||||
|
||||
|
|
|
@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
|
|||
|
||||
FILES_NUM = 4
|
||||
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
||||
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
|
||||
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
|
||||
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
||||
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
|
||||
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
|
||||
|
@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
|
|||
os.remove("{}".format(x))
|
||||
os.remove("{}.db".format(x))
|
||||
|
||||
|
||||
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
|
|||
assert num_iter == 20
|
||||
|
||||
|
||||
def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() < 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter < 10
|
||||
|
||||
def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
if os.path.exists(CV1_FILE_NAME):
|
||||
os.remove(CV1_FILE_NAME)
|
||||
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
if os.path.exists(CV2_FILE_NAME):
|
||||
os.remove(CV2_FILE_NAME)
|
||||
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV2_FILE_NAME))
|
||||
writer = FileWriter(CV1_FILE_NAME, 1)
|
||||
data = get_data(CV_DIR_NAME)
|
||||
cv_schema_json = {"id": {"type": "int32"},
|
||||
"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"},
|
||||
"data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "CV1_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
writer = FileWriter(CV2_FILE_NAME, 1)
|
||||
data = get_data(CV_DIR_NAME)
|
||||
cv_schema_json = {"id": {"type": "int32"},
|
||||
"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"},
|
||||
"data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "CV2_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() == 30
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 30
|
||||
if os.path.exists(CV1_FILE_NAME):
|
||||
os.remove(CV1_FILE_NAME)
|
||||
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
if os.path.exists(CV2_FILE_NAME):
|
||||
os.remove(CV2_FILE_NAME)
|
||||
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV2_FILE_NAME))
|
||||
|
||||
def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
|
||||
paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
|
||||
for x in range(FILES_NUM)]
|
||||
for x in paths:
|
||||
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
|
||||
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
|
||||
writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
|
||||
data = get_data(CV_DIR_NAME)
|
||||
cv_schema_json = {"id": {"type": "int32"},
|
||||
"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"},
|
||||
"data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "CV1_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() < 20
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter < 20
|
||||
for x in paths:
|
||||
os.remove("{}".format(x))
|
||||
os.remove("{}.db".format(x))
|
||||
|
||||
|
||||
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.dataset as ds
|
|||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
CV_FILE_NAME = "./imagenet.mindrecord"
|
||||
CV1_FILE_NAME = "./imagenet1.mindrecord"
|
||||
|
||||
|
||||
def create_cv_mindrecord(files_num):
|
||||
|
@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
|
|||
writer.commit()
|
||||
|
||||
|
||||
def create_diff_schema_cv_mindrecord(files_num):
|
||||
"""tutorial for cv dataset writer."""
|
||||
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
|
||||
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
|
||||
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name_1", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
def create_diff_page_size_cv_mindrecord(files_num):
|
||||
"""tutorial for cv dataset writer."""
|
||||
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
|
||||
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
|
||||
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||
writer.set_page_size(1<< 26) #64MB
|
||||
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
def test_cv_lack_json():
|
||||
"""tutorial for cv minderdataset."""
|
||||
create_cv_mindrecord(1)
|
||||
|
@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
|
|||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
def test_cv_minddataset_reader_different_schema():
|
||||
create_cv_mindrecord(1)
|
||||
create_diff_schema_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
os.remove(CV1_FILE_NAME)
|
||||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
|
||||
def test_cv_minddataset_reader_different_page_size():
|
||||
create_cv_mindrecord(1)
|
||||
create_diff_page_size_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
os.remove(CV1_FILE_NAME)
|
||||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
|
|
|
@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
|
|||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
def test_cv_file_reader_file_list():
|
||||
"""tutorial for cv file partial reader."""
|
||||
reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 10
|
||||
|
||||
def test_cv_file_reader_partial_tutorial():
|
||||
"""tutorial for cv file partial reader."""
|
||||
reader = FileReader(CV_FILE_NAME + "0")
|
||||
|
|
Loading…
Reference in New Issue