Move python tree transformation into cpp pass.
parent a4048e192c
commit 57f3732ac3
@@ -487,6 +487,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       (void)builder->SetInColNames(in_col_names);
     } else if (key == "output_columns") {
       (void)builder->SetOutColNames(ToStringVector(value));
+    } else if (key == "columns_order") {
+      (void)builder->SetColOrder(ToStringVector(value));
     } else if (key == "num_parallel_workers") {
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "prefetch_size") {
@@ -835,6 +837,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetColumnsToLoad(columns_to_load);
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "schema_file_path" || key == "schema_json_string") {
       schema_exists = true;
     } else if (key == "num_samples") {
@@ -1225,6 +1229,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "num_samples") {
       (void)builder->SetTotalRows(ToInt(value));
     } else if (key == "num_shards") {
@@ -1314,6 +1320,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "num_samples") {
       (void)builder->SetNumSamples(ToInt(value));
     } else if (key == "num_shards") {
@@ -20,6 +20,7 @@
 #include <memory>
 #include <utility>
 #include <string>
+#include <algorithm>
 
 #include "dataset/engine/execution_tree.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
@@ -68,8 +69,45 @@ Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
   return Status::OK();
 }
 
+Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
+  if (operator_id_ == kInvalidOperatorId) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only "
+      "be made if the node belongs to a tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // disallow relationships with other trees
+  if (tree_ != child->tree_) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
+  child->RemoveParent(this);
+  return Status::OK();
+}
+
+Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
+  for (auto &prev_parent : this->parent_) {
+    RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
+    RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
+  }
+  RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
+  if (tree_->root()->id() == this->id()) {
+    RETURN_IF_NOT_OK(tree_->AssignRoot(to_add));
+  }
+  return Status::OK();
+}
+
 // Adds a parent operator to this operator
-void DatasetOp::AddParent(const DatasetOp *parent) { parent_.push_back(parent); }
+void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
+
+// Removes a parent operator from this operator
+void DatasetOp::RemoveParent(DatasetOp *parent) {
+  parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
+}
 
 // Getter function to get a shared pointer to our child
 std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
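The rewiring that RemoveChild and InsertAsParent perform is easiest to follow outside the engine. Below is a minimal, self-contained C++ model of the same parent-transfer logic; the Node struct and its names are hypothetical stand-ins for DatasetOp, not engine code. Note that it iterates over a copy of the parent list, since detaching mutates that list while the loop runs.

// Standalone sketch of the InsertAsParent rewiring (hypothetical Node type).
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node : std::enable_shared_from_this<Node> {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<std::shared_ptr<Node>> children;  // owning, like DatasetOp::child_
  std::vector<Node *> parents;                  // non-owning, like DatasetOp::parent_

  void AddChild(const std::shared_ptr<Node> &c) {
    children.push_back(c);
    c->parents.push_back(this);
  }
  void RemoveChild(const std::shared_ptr<Node> &c) {
    children.erase(std::remove(children.begin(), children.end(), c), children.end());
    c->parents.erase(std::remove(c->parents.begin(), c->parents.end(), this), c->parents.end());
  }
  // Detach from the old parents, attach the new op there, then adopt this node under it.
  void InsertAsParent(const std::shared_ptr<Node> &to_add) {
    for (Node *p : std::vector<Node *>(parents)) {  // copy: RemoveChild edits `parents`
      p->RemoveChild(shared_from_this());
      p->AddChild(to_add);
    }
    to_add->AddChild(shared_from_this());
  }
};

int main() {
  auto batch = std::make_shared<Node>("Batch");
  auto reader = std::make_shared<Node>("TFReader");
  batch->AddChild(reader);  // Batch -> TFReader
  reader->InsertAsParent(std::make_shared<Node>("Shuffle"));
  std::cout << batch->children[0]->name << std::endl;  // prints "Shuffle": Batch -> Shuffle -> TFReader
  return 0;
}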
@@ -64,10 +64,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @param child - shared pointer to the child to add.
   Status AddChild(std::shared_ptr<DatasetOp> child);
 
+  // Removes an operator from our children.
+  // @param child - shared pointer to the child to remove.
+  Status RemoveChild(std::shared_ptr<DatasetOp> child);
+
   // Getter function to get a shared pointer to our child
   // @param child_index - An operator can have n children. Indicates which child to return.
   std::shared_ptr<DatasetOp> child(int32_t child_index) const;
 
+  // Inserts an operator as the parent of the current op.
+  // The inserted op will become the sole parent of the current op.
+  // The existing parents of the current op will be transferred to the inserted op.
+  Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
+
   // Creates the connector within this operator
   // @param num_producers - number of threads that write into this connector
   // @param num_consumers - number of threads that read from this connector

@@ -261,7 +270,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // Adds a parent operator to this operator
   // @notes External callers do not have access to this function.
   // @param parent - The parent node to add
-  void AddParent(const DatasetOp *parent);
+  void AddParent(DatasetOp *parent);
+
+  // Removes a parent operator from this operator
+  // @notes External callers do not have access to this function.
+  // @param parent - The parent node to remove
+  void RemoveParent(DatasetOp *parent);
 
   // A helper function for providing an assignment of the column name map.
   // This grabs the map from child 0 and assigns it into this op.

@@ -270,7 +284,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   Status AssignColMapFromChild();
 
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
-  std::vector<const DatasetOp *> parent_;          // Parent nodes. No ownership and read-only
+  std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
   int32_t oc_queue_size_;                          // Capacity for each out_connector_
   int32_t operator_id_;                            // Generated id for the node
   ExecutionTree *tree_;                            // Back pointer to our tree.
@@ -54,19 +54,20 @@ Status MapOp::Builder::sanityCheck() const {
 Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
   RETURN_IF_NOT_OK(sanityCheck());
   *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
-                                 std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_,
-                                 build_perf_mode_);
+                                 std::move(build_tensor_funcs_), std::move(build_col_order_), build_num_workers_,
+                                 build_op_connector_size_, build_perf_mode_);
   return Status::OK();
 }
 
 // Constructor of MapOp
 MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
-             bool perf_mode)
+             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
+             int32_t num_workers, int32_t op_connector_size, bool perf_mode)
     : ParallelOp(num_workers, op_connector_size),
       tfuncs_(std::move(tensor_funcs)),
       in_columns_(in_col_names),
       out_columns_(out_col_names),
+      columns_order_(columns_order),
       perf_mode_(perf_mode) {
   // If the caller didn't specify out_col_names, assume they are the same as the in_columns.
   if (out_columns_.empty() || out_columns_[0].empty()) {
@@ -93,6 +93,13 @@ class MapOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetColOrder(const std::vector<std::string> &col_order_) {
+      build_col_order_ = col_order_;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder setter method returns reference to the builder.
     Builder &SetNumWorkers(int32_t num_workers) {

@@ -123,6 +130,7 @@ class MapOp : public ParallelOp {
     std::vector<std::string> build_in_col_names_;
    std::vector<std::string> build_out_col_names_;
     std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
+    std::vector<std::string> build_col_order_;
     int32_t build_num_workers_;
     int32_t build_op_connector_size_;
     bool build_perf_mode_;  // Default true.

@@ -137,11 +145,12 @@ class MapOp : public ParallelOp {
   // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs).
   // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs).
   // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data.
+  // @param columns_order A full list of column names (should match the whole dataset view post \p tensorFuncs).
   // @param num_workers The number of worker threads.
   // @param op_connector_size The size of each queue in the connector.
   MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
-        bool perf_mode);
+        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
+        int32_t num_workers, int32_t op_connector_size, bool perf_mode);
 
   // Destructor
   ~MapOp() = default;

@@ -181,6 +190,10 @@ class MapOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "MapOp"; }
 
+  // Columns order getter
+  // @return The post-map columns order
+  std::vector<std::string> const &ColumnsOrder() const { return columns_order_; }
+
  private:
   // Local queues where worker threads can pop from.
   // Popping directly from the Connector can block if the previous designated threads haven't popped.

@@ -202,6 +215,9 @@ class MapOp : public ParallelOp {
   // Indices of the columns to process.
   std::vector<size_t> to_process_indices_;
 
+  // Stores the column order of all columns post tensorOps
+  std::vector<std::string> columns_order_;
+
   // Performance mode is when the main thread creates local queues, pulls databuffers from the previous
   // op's Connector and distributes them to the local queues. Workers pull from the local queues.
   // If this flag is false, each worker pulls directly from the Connector. This uses fewer resources
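For orientation, this is roughly how ParseMapOp above drives the extended builder once "columns_order" arrives from Python. A hedged sketch of an engine-side call site: the column names and worker count are invented, the tensor-funcs step is elided, and this fragment assumes the MapOp headers; it is not a verbatim engine call.

// Illustrative use of the extended MapOp builder (sketch, not engine code).
std::shared_ptr<MapOp> map_op;
std::shared_ptr<MapOp::Builder> builder = std::make_shared<MapOp::Builder>();
(void)builder->SetInColNames({"image"});             // columns fed to the tensor ops
(void)builder->SetOutColNames({"image_out"});        // columns they produce
(void)builder->SetColOrder({"label", "image_out"});  // full post-map column view
(void)builder->SetNumWorkers(4);
// ... the TensorOp list would also be set here before building ...
RETURN_IF_NOT_OK(builder->Build(&map_op));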
@@ -32,7 +32,11 @@ namespace mindspore {
 namespace dataset {
 
 ClueOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_num_samples_(0),
+      builder_shuffle_files_(false),
+      builder_shuffle_global_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();

@@ -63,8 +67,8 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
 
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
-    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
-    builder_device_id_);
+    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_shuffle_global_,
+    builder_num_devices_, builder_device_id_);
   RETURN_IF_NOT_OK(clue_op->Init());
   *op = std::move(clue_op);

@@ -84,7 +88,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
 
 ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
                ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-               bool shuffle_files, int32_t num_device, int32_t device_id)
+               bool shuffle_files, bool shuffle_global, int32_t num_device, int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       rows_per_buffer_(rows_per_buffer),
       num_rows_per_shard_(0),

@@ -95,6 +99,7 @@ ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples
       load_jagged_connector_(true),
       cols_to_keyword_(cols_to_keyword),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       finished_reading_dataset_(false),
       num_devices_(num_device),
       device_id_(device_id),
@@ -104,6 +104,13 @@ class ClueOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetNumSamples(int64_t num_samples) {

@@ -132,13 +139,15 @@ class ClueOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_clue_files_list_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     std::map<std::string, std::string> builder_cols_to_keyword_;
   };
 
   // Constructor of ClueOp
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
          ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-         bool shuffle_files, int32_t num_devices, int32_t device_id);
+         bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~ClueOp() = default;

@@ -169,6 +178,14 @@ class ClueOp : public ParallelOp {
   // @return Status - the error code returned.
   static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return clue_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -248,6 +265,7 @@ class ClueOp : public ParallelOp {
 
   int32_t device_id_;
   bool shuffle_files_;
+  bool shuffle_global_;
   bool finished_reading_dataset_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
@@ -33,7 +33,11 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_shuffle_files_(false),
+      builder_shuffle_global_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();

@@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
-    builder_num_devices_, builder_device_id_);
+    builder_shuffle_global_, builder_num_devices_, builder_device_id_);
   RETURN_IF_NOT_OK(text_file_op->Init());
   *op = std::move(text_file_op);

@@ -73,7 +77,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
 
 TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
-                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
+                       int32_t op_connector_size, bool shuffle_files, bool shuffle_global, int32_t num_device,
+                       int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),

@@ -81,6 +86,7 @@ TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t tot
       total_rows_(total_rows),
       text_files_list_(std::move(text_files_list)),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       data_schema_(std::move(schema)),
       all_num_rows_(0),
       num_rows_per_shard_(0),
@@ -105,6 +105,13 @@ class TextFileOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetTotalRows(int64_t total_rows) {

@@ -122,6 +129,7 @@ class TextFileOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     std::unique_ptr<DataSchema> builder_schema_;
   };
 

@@ -135,10 +143,11 @@ class TextFileOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
              std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, int32_t num_devices, int32_t device_id);
+             bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~TextFileOp() = default;

@@ -173,6 +182,14 @@ class TextFileOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "TextFileOp"; }
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return text_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -253,6 +270,7 @@ class TextFileOp : public ParallelOp {
   int64_t total_rows_;
   std::vector<std::string> text_files_list_;
   bool shuffle_files_;
+  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   int64_t all_num_rows_;
   int64_t num_rows_per_shard_;
@@ -56,6 +56,7 @@ TFReaderOp::Builder::Builder()
   builder_op_connector_size_ = config_manager->op_connector_size();
   builder_rows_per_buffer_ = config_manager->rows_per_buffer();
   builder_shuffle_files_ = false;
+  builder_shuffle_global_ = false;
   builder_data_schema_ = std::make_unique<DataSchema>();
 }
 

@@ -126,7 +127,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
   std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
     builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
     builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
-    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
+    builder_shuffle_files_, builder_shuffle_global_, builder_num_devices_, builder_device_id_,
+    builder_equal_rows_per_shard_);
 
   RETURN_IF_NOT_OK(new_tf_reader_op->Init());
   *out_tf_reader_op = std::move(new_tf_reader_op);

@@ -136,8 +138,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
 TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
                        int64_t total_num_rows, std::vector<std::string> dataset_files_list,
                        std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
-                       std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
-                       int32_t device_id, bool equal_rows_per_shard)
+                       std::vector<std::string> columns_to_load, bool shuffle_files, bool shuffle_global,
+                       int32_t num_device, int32_t device_id, bool equal_rows_per_shard)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),

@@ -147,6 +149,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
       columns_to_load_(std::move(columns_to_load)),
       finished_reading_dataset_(false),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       data_schema_(std::move(data_schema)),
       filename_index_(std::make_unique<StringIndex>()),
       load_io_block_queue_(true),

@@ -172,7 +175,8 @@ void TFReaderOp::Print(std::ostream &out, bool show_all) const {
   // Then show any custom derived-internal stuff
   out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_
       << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
-      << "\nDataset files list:\n";
+      << "\nShuffle global: " << ((shuffle_global_) ? "yes" : "no")
+      << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
   for (int i = 0; i < dataset_files_list_.size(); ++i) {
     out << " " << dataset_files_list_[i];
   }

@@ -217,7 +221,6 @@ Status TFReaderOp::Init() {
   // temporary: make size large enough to hold all files + EOE to avoid hangs
   int32_t safe_queue_size = static_cast<int32_t>(std::ceil(dataset_files_list_.size() / num_workers_)) + 1;
   io_block_queues_.Init(num_workers_, safe_queue_size);
-  dataset_files_list_.clear();  // no longer need the original list of files
 
   return Status::OK();
 }
@@ -146,6 +146,13 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetShardEqualRows(bool shard_equal_rows) {

@@ -165,6 +172,7 @@ class TFReaderOp : public ParallelOp {
     std::vector<std::string> builder_dataset_files_list_;
     std::vector<std::string> builder_columns_to_load_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     bool builder_equal_rows_per_shard_;
   };
 

@@ -179,11 +187,12 @@ class TFReaderOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
              int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             bool shuffle_global, int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
 
   // Default destructor
   ~TFReaderOp() = default;

@@ -232,6 +241,14 @@ class TFReaderOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "TFReaderOp"; }
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return dataset_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -372,6 +389,7 @@ class TFReaderOp : public ParallelOp {
   std::vector<std::string> columns_to_load_;
   bool finished_reading_dataset_;
   bool shuffle_files_;
+  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   std::unique_ptr<StringIndex> filename_index_;
   bool load_io_block_queue_;
@@ -19,6 +19,8 @@
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/datasetops/shuffle_op.h"
 #include "dataset/util/task_manager.h"
+#include "dataset/engine/opt/pre/map_column_reorder.h"
+#include "dataset/engine/opt/pre/global_shuffle.h"
 #include "dataset/engine/perf/profiling.h"
 #include "dataset/engine/perf/monitor.h"
 

@@ -79,8 +81,6 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
   // Then add it as the root.
   root_ = op;
 
-  // The tree has an assigned root now and it's ready to be prepared.
-  tree_state_ = kDeTStatePrepare;
   return Status::OK();
 }
 

@@ -207,9 +207,24 @@ Status ExecutionTree::Prepare() {
   return Status::OK();
 }
 
-Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
+Status ExecutionTree::PrepareTreePreAction() {
+  bool modified = false;
+  std::vector<Pass *> pre_actions;
+  // Construct pre actions
+  pre_actions.push_back(new MapColumnReorder());
+  pre_actions.push_back(new GlobalShufflePass());
+  // Apply pre action passes
+  for (auto &pass : pre_actions) {
+    RETURN_IF_NOT_OK(pass->Run(this, &modified));
+  }
+  return Status::OK();
+}
 
-Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
+Status ExecutionTree::PrepareTreePostAction() {
+  // The tree is ready to be prepared.
+  tree_state_ = kDeTStatePrepare;
+  return Status::OK();
+}
 
 Status ExecutionTree::Optimize() {
   // auto pp = new PrinterPass();
@@ -2,5 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-opt OBJECT
         pass.cc
+        pre/map_column_reorder.cc
+        pre/global_shuffle.cc
         util/printer_pass.cc
     )
@@ -37,10 +37,18 @@ namespace mindspore {
 namespace dataset {
 
 // Driver method for TreePass
-Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }
+Status TreePass::Run(ExecutionTree *tree, bool *modified) {
+  if (!tree || !modified) {
+    return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
+  }
+  return this->RunOnTree(tree, modified);
+}
 
 // Driver method for NodePass
 Status NodePass::Run(ExecutionTree *tree, bool *modified) {
+  if (!tree || !modified) {
+    return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
+  }
   std::shared_ptr<DatasetOp> root = tree->root();
   if (traversalOrder_ == Order::DFS) {
     // DFS
@@ -57,10 +57,10 @@ class ImageFolderOp;
 // The actual implementation of the passes will be derived from here.
 class Pass : public std::enable_shared_from_this<Pass> {
  public:
-  // Run the transformation pass again the execution tree.
+  // Run the transformation pass against the execution tree.
   // @param tree - Pointer to the execution tree to be transformed.
   // @param modified - Pointer to the modified flag.
-  virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
+  virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
 };
 
 // TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
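Making Pass::Run pure virtual forces every concrete pass to supply a driver. Any new tree-level transformation follows the same recipe as the two passes added below this header: derive from TreePass and override RunOnTree. As a hedged illustration, here is a hypothetical diagnostic pass (OpCountPass is invented; only the TreePass interface and the tree iteration are taken from this commit):

// Hypothetical diagnostic pass, sketched against the TreePass interface above.
class OpCountPass : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override {
    int32_t op_count = 0;
    for (auto &op : *tree) {
      (void)op;  // every DatasetOp in the tree is visited once
      ++op_count;
    }
    MS_LOG(INFO) << "Execution tree holds " << op_count << " operators.";
    *modified = false;  // tree structure left untouched
    return Status::OK();
  }
};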
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+#include <algorithm>
+#include "dataset/engine/opt/pre/global_shuffle.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/shuffle_op.h"
+#include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/source/text_file_op.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<TFReaderOp>> tf_readers;
+  std::vector<std::shared_ptr<TextFileOp>> text_files;
+  std::vector<std::shared_ptr<ClueOp>> clues;
+
+  // Pass 1, search for all sources which require global shuffle
+  for (auto &op : *tree) {
+    if (auto ptr = std::dynamic_pointer_cast<TFReaderOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        tf_readers.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<TextFileOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        text_files.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<ClueOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        clues.push_back(ptr);
+        continue;
+      }
+    }
+  }
+
+  // Pass 2, insert shuffle nodes
+  // The following blocks could be implemented with a template if CountTotalRows is unified across all source nodes.
+  for (auto node : tf_readers) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true));
+    int32_t avg_file_size = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : text_files) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_file_size = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : clues) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_file_size = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
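The only tunable in the pass above is the shuffle window: four times the average rows per file, with a floor of 10000 rows. A standalone sketch of that arithmetic (the function name and sample values are illustrative only):

// Mirrors the shuffle-size heuristic used by GlobalShufflePass.
#include <algorithm>
#include <cstdint>
#include <iostream>

int32_t ShuffleSizeFor(int64_t total_rows, int64_t num_files) {
  int32_t avg_file_size = static_cast<int32_t>(total_rows / num_files);
  return std::max(avg_file_size * 4, 10000);  // floor of 10000 rows
}

int main() {
  std::cout << ShuffleSizeFor(120000, 12) << std::endl;  // 12 files, 120000 rows -> 40000
  std::cout << ShuffleSizeFor(300, 3) << std::endl;      // tiny dataset -> floor, 10000
  return 0;
}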
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+// Global Shuffle Pass will insert a ShuffleOp when a leaf node requires global shuffle.
+// Example:
+// Input Tree:  TFReader(GLOBAL_SHUFFLE) -> Batch
+// Output Tree: TFReader -> Shuffle -> Batch
+class GlobalShufflePass : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+#include "dataset/engine/opt/pre/map_column_reorder.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/map_op.h"
+#include "dataset/engine/datasetops/project_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<MapOp>> to_process;
+
+  // Pass 1, search for all MapOps with column orders
+  for (auto &op : *tree) {
+    if (auto mapOp = std::dynamic_pointer_cast<MapOp>(op.shared_from_this())) {
+      if (mapOp->ColumnsOrder().size() != 0) {
+        to_process.push_back(mapOp);
+      }
+    }
+  }
+
+  // Pass 2, insert a ProjectOp above each flagged MapOp
+  for (auto node : to_process) {
+    std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(node->ColumnsOrder());
+    std::shared_ptr<ProjectOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+// Map Column Reorder Pass will insert a ProjectOp when a MapOp requires a full output column reorder.
+// Example:
+// Input Tree:  TFReader -> MapOp(with col_order) -> Batch
+// Output Tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch
+class MapColumnReorder : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
@@ -1910,6 +1910,7 @@ class MapDataset(DatasetOp):
         args["input_columns"] = self.input_columns
         args["operations"] = self.operations
         args["output_columns"] = self.output_columns
+        args["columns_order"] = self.columns_order
         return args
 
     def get_dataset_size(self):

@@ -3299,6 +3300,7 @@ class TFRecordDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id

@@ -4607,6 +4609,7 @@ class CLUEDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id

@@ -4697,6 +4700,7 @@ class TextFileDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
@@ -49,33 +49,13 @@ def alter_tree(node):
 
 
 def _alter_node(node):
-    """Performing some alteration to a dataset node. A common alteration is to insert a node."""
-    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
-            and node.shuffle_level == de.Shuffle.GLOBAL:
-        # Remove the connection between the parent's node to the current node because we are inserting a node.
-        if node.output:
-            node.output.pop()
-        # Perform a fast scan for average rows per file
-        if isinstance(node, de.TFRecordDataset):
-            avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
-        else:
-            avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
-
-        # Shuffle between 4 files with a minimum size of 10000 rows
-        new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
-        return new_shuffle
-
+    """DEPRECATED"""
+    # Please check ccsrc/dataset/engine/opt for tree transformation.
     if isinstance(node, de.MapDataset):
         if node.python_multiprocessing:
             # Bootstrap can only be performed on a copy of the original dataset node.
             # Bootstrap on original dataset node will make all iterators share the same process pool
             node.iterator_bootstrap()
-        if node.columns_order is not None:
-            # Remove the connection between the parent's node to the current node because we are inserting a node.
-            if node.output:
-                node.output.pop()
-
-            return node.project(node.columns_order)
     return node
 
 
@@ -51,6 +51,7 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
   ASSERT_NE(my_tfreader_op, nullptr);
   parent_op->AddChild(std::move(my_tfreader_op));
   MS_LOG(INFO) << parent_op;
+  my_tree->AssignRoot(parent_op);
   my_tree->Prepare();
 
   RepeatOp RepeatOpOp();
@@ -0,0 +1,90 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+
+import mindspore.dataset as ds
+
+
+def test_map_reorder_pass_0():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]))
+
+    # Generator -> Map
+    data0 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
+
+    data0 = data0.map(input_columns="col0", output_columns="out", columns_order=["col1", "out"],
+                      operations=(lambda x: x))
+
+    for item in data0.create_tuple_iterator():  # each item is a list of two tensors
+        assert item == [np.array(1), np.array(0)]
+
+
+def test_map_reorder_pass_1():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]), np.array([2]))
+
+    # Three maps and zip
+    data0 = ds.GeneratorDataset(generator_mc, ["a0", "a1", "a2"])
+    data0 = data0.map(input_columns="a0", columns_order=["a2", "a1", "a0"], operations=(lambda x: x))
+    data1 = ds.GeneratorDataset(generator_mc, ["b0", "b1", "b2"])
+    data1 = data1.map(input_columns="b0", columns_order=["b1", "b2", "b0"], operations=(lambda x: x))
+    data2 = ds.zip((data0, data1))
+    data2 = data2.map(input_columns="a0", columns_order=["b2", "a2", "b1", "a1", "b0", "a0"], operations=(lambda x: x))
+
+    for item in data2.create_tuple_iterator():
+        assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)]
+
+
+def test_global_shuffle_pass():
+
+    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
+    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+
+    ds.config.set_seed(1)
+    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+    ds.config.set_seed(1)
+    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
+    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+    ds.config.set_seed(1)
+    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
+    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+
+if __name__ == "__main__":
+    test_map_reorder_pass_0()
+    test_map_reorder_pass_1()
+    test_global_shuffle_pass()