forked from mindspore-Ecosystem/mindspore
!1272 [Dataset] MindData Tree Optimizer Infrastructure
Merge pull request !1272 from JunhanHu/minddata_opt
This commit is contained in:
commit
93e7c97a96
|
@ -66,6 +66,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:engine-datasetops-source>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
||||
$<TARGET_OBJECTS:engine-datasetops>
|
||||
$<TARGET_OBJECTS:engine-opt>
|
||||
$<TARGET_OBJECTS:engine>
|
||||
)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_subdirectory(datasetops)
|
||||
add_subdirectory(opt)
|
||||
if (ENABLE_TDTQUE)
|
||||
add_subdirectory(tdt)
|
||||
endif ()
|
||||
|
@ -14,7 +15,7 @@ add_library(engine OBJECT
|
|||
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
|
||||
if (ENABLE_TDTQUE)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
|
||||
else()
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
|
||||
endif ()
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "dataset/core/pybind_support.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
using float16 = Eigen::half;
|
||||
|
||||
|
@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status BatchOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -192,6 +192,12 @@ class BatchOp : public ParallelOp {
|
|||
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
float pad_val);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
|
||||
// it is only meant to be called by PadTensor.
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOp::Accept(NodePass *p, bool *modified) {
|
||||
// DatasetOp is the base class of visitor target.
|
||||
// This method will only be called if its derived class does not implement one.
|
||||
return p->RunOnNode(shared_from_this(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,8 @@ class ExecutionTree;
|
|||
|
||||
class DataBuffer;
|
||||
|
||||
class NodePass;
|
||||
|
||||
// The base class DatasetOp is the main tree node. It is an abstract class, so
|
||||
// the actual implementation of the operators will be derived from here.
|
||||
class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||
|
@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// @return - the column name map as a string
|
||||
std::string ColumnNameMapAsString() const;
|
||||
|
||||
// Children Getter
|
||||
// @return Vector or Children
|
||||
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
|
||||
|
||||
// Base method for NodePass visit.
|
||||
// Subclass needs to override this if it requires special node visit access.
|
||||
// Check "dataset/engine/opt/pass.h" for more details.
|
||||
// @return Statue of the node visit
|
||||
virtual Status Accept(NodePass *p, bool *modified);
|
||||
|
||||
protected:
|
||||
// Adds a parent operator to this operator
|
||||
// @notes External callers do not have access to this function.
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "dataset/engine/dataset_iterator.h"
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "tdt/tsd_client.h"
|
||||
|
@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
|
|||
out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp {
|
|||
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// Name: checkExceptions(DataBuffer);
|
||||
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
|
|||
}
|
||||
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status FilterOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
|
|||
// @param show_all A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// predicate_func python callable which returns a boolean value.
|
||||
py::function predicate_func_;
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
|
|||
column_name_id_map_ = final_col_name_id_map;
|
||||
}
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status MapOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -171,6 +171,12 @@ class MapOp : public ParallelOp {
|
|||
// @return the number of threads consuming data from previous op's output Connector.
|
||||
int32_t num_consumers() const override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// Local queues where worker threads can pop from.
|
||||
// Popping directly from the Connector can block if the previous designated threads haven't pop.
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) {
|
|||
}
|
||||
|
||||
Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ProjectOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp {
|
|||
// @return Status - The error code returned.
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_to_project_;
|
||||
std::vector<int32_t> projected_column_indices_;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) {
|
|||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status RenameOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -110,6 +110,12 @@ class RenameOp : public PipelineOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
protected:
|
||||
// Rename core functionality
|
||||
Status RenameColumns();
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "dataset/engine/datasetops/repeat_op.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const {
|
|||
return child_[0]->num_producers();
|
||||
}
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp {
|
|||
// @param workerId - The worker id
|
||||
int32_t num_producers() const override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
int32_t max_repeats_; // The number of repeats that the user requested
|
||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "dataset/engine/dataset_iterator.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
|
@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
|
|||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp {
|
|||
// @return Status - The error code return
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// Private function to add a new row to the shuffle buffer.
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "dataset/engine/datasetops/skip_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) {
|
|||
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status SkipOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,6 +74,12 @@ class SkipOp : public PipelineOp {
|
|||
// @param worker_id - The worker id
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
int32_t max_skips_; // The number of skips that the user requested
|
||||
int32_t skip_count_; // A counter for the current number of executed skips
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -250,5 +251,11 @@ Status GeneratorOp::Reset() {
|
|||
wp_.Set();
|
||||
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp {
|
|||
// @return Status - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
py::function generator_function_;
|
||||
std::vector<std::string> column_names_;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
|
|||
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
|
||||
int64_t dev_id = 0, int64_t num_dev = 1);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "dataset/engine/datasetops/dataset_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp {
|
|||
|
||||
Status SetColumnsBlob();
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/util/path.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/util/random.h"
|
||||
|
@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
|
|||
|
||||
return rows_read;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp {
|
|||
static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
|
||||
bool estimate = false);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "dataset/engine/datasetops/take_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() {
|
|||
tree_->AddToRepeatStack(shared_from_this());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status TakeOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
|
|||
// before providing their own implementations.
|
||||
Status PrepareNodePostAction() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
int32_t max_takes_; // The number of takes that the user requested
|
||||
int32_t take_count_; // A counter for the current number of executed takes
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "dataset/core/constants.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) {
|
|||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ZipOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -104,6 +104,12 @@ class ZipOp : public PipelineOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// Handles preprocessing of the main loop, used when starting new epoch
|
||||
Status prepare(TensorQTable *const table);
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
||||
#include "dataset/engine/opt/util/printer_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor
|
||||
|
@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// The driver of the prepare phase of the execution tree.
|
||||
// Prepare phase consists of three sub phases
|
||||
//
|
||||
// 1. PrepareTreePreAction()
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// For example, CacheOp Insertion
|
||||
//
|
||||
// 2. Optimize()
|
||||
// Optimization transformation/action, optional
|
||||
// For example, MapOp Fusion
|
||||
//
|
||||
// 3. PrepareTreePostAction()
|
||||
// Compulsory transformation/action post optimization.
|
||||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
Status ExecutionTree::Prepare() {
|
||||
// Pre optimization compulsory transformation
|
||||
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
|
||||
|
||||
// Optimization transformation
|
||||
RETURN_IF_NOT_OK(this->Optimize());
|
||||
|
||||
// Post optimization compulsory transformation
|
||||
RETURN_IF_NOT_OK(this->PrepareTreePostAction());
|
||||
|
||||
// Existing transformation implementation, will be removed later
|
||||
RETURN_IF_NOT_OK(this->PrepareDeprecated());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
|
||||
|
||||
Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
|
||||
|
||||
Status ExecutionTree::Optimize() {
|
||||
// auto pp = new PrinterPass();
|
||||
// bool modified = false;
|
||||
// pp->Run(this, &modified);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||
// it ready for execution.
|
||||
Status ExecutionTree::Prepare() {
|
||||
//
|
||||
// This driver is deprecated.
|
||||
Status ExecutionTree::PrepareDeprecated() {
|
||||
// Tree must be in pending prepare state before we can assign root to it
|
||||
if (tree_state_ != kDeTStatePrepare) {
|
||||
std::string err_msg =
|
||||
|
|
|
@ -152,11 +152,41 @@ class ExecutionTree {
|
|||
// @return the prepare flags
|
||||
uint32_t PrepareFlags() const { return prepare_flags_; }
|
||||
|
||||
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||
// The driver of the prepare phase of the execution tree.
|
||||
// Prepare phase consists of three sub phases
|
||||
//
|
||||
// 1. PrepareTreePreAction()
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// For example, CacheOp Insertion
|
||||
//
|
||||
// 2. Optimize()
|
||||
// Optimization transformation/action, optional
|
||||
// For example, MapOp Fusion
|
||||
//
|
||||
// 3. PrepareTreePostAction()
|
||||
// Compulsory transformation/action post optimization.
|
||||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
Status Prepare();
|
||||
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// @return Status - The error code return
|
||||
Status PrepareTreePreAction();
|
||||
|
||||
// Compulsory transformation/action post optimization.
|
||||
// @return Status - The error code return
|
||||
Status PrepareTreePostAction();
|
||||
|
||||
// Optimization transformation/action, optional.
|
||||
// @return Status - The error code return
|
||||
Status Optimize();
|
||||
|
||||
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||
// it ready for execution.
|
||||
// @return Status - The error code return
|
||||
Status Prepare();
|
||||
Status PrepareDeprecated();
|
||||
|
||||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||
// node actions during a tree walk.
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-opt OBJECT
|
||||
pass.cc
|
||||
util/printer_pass.cc
|
||||
)
|
|
@ -0,0 +1,157 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/engine/datasetops/batch_op.h"
|
||||
#include "dataset/engine/datasetops/dataset_op.h"
|
||||
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "dataset/engine/datasetops/map_op.h"
|
||||
#include "dataset/engine/datasetops/project_op.h"
|
||||
#include "dataset/engine/datasetops/rename_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "dataset/engine/datasetops/repeat_op.h"
|
||||
#include "dataset/engine/datasetops/skip_op.h"
|
||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/take_op.h"
|
||||
#include "dataset/engine/datasetops/zip_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Driver method for TreePass
|
||||
Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }
|
||||
|
||||
// Driver method for NodePass
|
||||
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
|
||||
std::shared_ptr<DatasetOp> root = tree->root();
|
||||
if (traversalOrder_ == Order::DFS) {
|
||||
// DFS
|
||||
return DFSNodeVisit(root, modified);
|
||||
} else if (traversalOrder_ == Order::BFS) {
|
||||
// BFS
|
||||
return BFSNodeVisit(root, modified);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to perform DFS visit
|
||||
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||
for (const auto &c : node->Children()) {
|
||||
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
|
||||
}
|
||||
return node->Accept(this, modified);
|
||||
}
|
||||
|
||||
// Helper function to perform BFS visit
|
||||
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) {
|
||||
// Initialize bfs queue with root
|
||||
std::queue<std::shared_ptr<DatasetOp>> bfsQueue;
|
||||
bfsQueue.push(root);
|
||||
|
||||
// BFS loop
|
||||
while (!bfsQueue.empty()) {
|
||||
// Pop the front of the bfs queue
|
||||
auto curNode = bfsQueue.front();
|
||||
bfsQueue.pop();
|
||||
|
||||
// Run node pass
|
||||
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
|
||||
|
||||
// Push children into bfs queue
|
||||
for (const auto &c : curNode->Children()) {
|
||||
bfsQueue.push(c);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class BatchOp;
|
||||
|
||||
class MapOp;
|
||||
|
||||
class ProjectOp;
|
||||
|
||||
class RenameOp;
|
||||
|
||||
class FilterOp;
|
||||
|
||||
class SkipOp;
|
||||
|
||||
class ShuffleOp;
|
||||
|
||||
class GeneratorOp;
|
||||
|
||||
class MindRecordOp;
|
||||
|
||||
class TFReaderOp;
|
||||
|
||||
class TakeOp;
|
||||
|
||||
class ZipOp;
|
||||
|
||||
class DeviceQueueOp;
|
||||
|
||||
class ImageFolderOp;
|
||||
|
||||
// The base class Pass is the basic unit of tree transformation.
|
||||
// The actual implementation of the passes will be derived from here.
|
||||
class Pass : public std::enable_shared_from_this<Pass> {
|
||||
public:
|
||||
// Run the transformation pass again the execution tree.
|
||||
// @param tree - Pointer to the execution tree to be transformed.
|
||||
// @param modified - Pointer to the modified flag,
|
||||
virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||
};
|
||||
|
||||
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
|
||||
class TreePass : public Pass {
|
||||
public:
|
||||
// Run the transformation pass against the execution tree.
|
||||
// @param tree - Pointer to the execution tree to be transformed.
|
||||
// @param modified - Pointer to the modified flag,
|
||||
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||
|
||||
// Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
// @return Status - The error code return
|
||||
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||
};
|
||||
|
||||
// NodePass is a basic Pass class which performs transformation on Node visiting.
|
||||
// NodePass implements Visitor design pattern.
|
||||
class NodePass : public Pass {
|
||||
public:
|
||||
// Tree traversal order
|
||||
enum Order { DFS, BFS };
|
||||
|
||||
// Constructor
|
||||
// Default DFS traversal
|
||||
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
|
||||
|
||||
// Run the transformation pass against the execution tree.
|
||||
// @param tree - Pointer to the execution tree to be transformed.
|
||||
// @param modified - Pointer to the modified flag,
|
||||
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||
|
||||
// Derived classes may implement the runOnNode function to implement node level tree transformation.
|
||||
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
// @return Status - The error code return
|
||||
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
|
||||
|
||||
// Visit methods to be overridden.
|
||||
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
|
||||
// of its own type and override "Accept" from DatasetOp.
|
||||
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
||||
|
||||
private:
|
||||
// Helper function to perform DFS visit
|
||||
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||
|
||||
// Helper function to perform BFS visit
|
||||
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
|
||||
|
||||
// Tree traversal order of the NodePass
|
||||
Order traversalOrder_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_H_
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/util/printer_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting DatasetOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting BatchOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting MapOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ProjectOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting RenameOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting FilterOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting SkipOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ShuffleOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting GeneratorOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting MindRecordOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting TFReaderOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting TakeOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ZipOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting DeviceQueueOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ImageFolderOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||
#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class PrinterPass : public NodePass {
|
||||
public:
|
||||
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
||||
# Generate 1d int numpy array from 0 - 63
|
||||
def generator_1d():
|
||||
for i in range(64):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
def test_case_0():
|
||||
"""
|
||||
Test 1D Generator
|
||||
"""
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
|
||||
data1 = data1.shuffle(2)
|
||||
|
||||
data1 = data1.map(["data"], operations=(lambda x : x))
|
||||
|
||||
data1 = data1.batch(2)
|
||||
|
||||
i = 0
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_case_0()
|
Loading…
Reference in New Issue