MindData optimizer infrastructure.
This commit is contained in:
parent
6cbde2b3bb
commit
f44d213503
|
@ -65,6 +65,7 @@ set(submodules
|
||||||
$<TARGET_OBJECTS:engine-datasetops-source>
|
$<TARGET_OBJECTS:engine-datasetops-source>
|
||||||
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
||||||
$<TARGET_OBJECTS:engine-datasetops>
|
$<TARGET_OBJECTS:engine-datasetops>
|
||||||
|
$<TARGET_OBJECTS:engine-opt>
|
||||||
$<TARGET_OBJECTS:engine>
|
$<TARGET_OBJECTS:engine>
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
add_subdirectory(datasetops)
|
add_subdirectory(datasetops)
|
||||||
|
add_subdirectory(opt)
|
||||||
if (ENABLE_TDTQUE)
|
if (ENABLE_TDTQUE)
|
||||||
add_subdirectory(tdt)
|
add_subdirectory(tdt)
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -14,7 +15,7 @@ add_library(engine OBJECT
|
||||||
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (ENABLE_TDTQUE)
|
if (ENABLE_TDTQUE)
|
||||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt)
|
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
|
||||||
else()
|
else()
|
||||||
add_dependencies(engine engine-datasetops engine-datasetops-source)
|
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/core/pybind_support.h"
|
#include "dataset/core/pybind_support.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
using float16 = Eigen::half;
|
using float16 = Eigen::half;
|
||||||
|
|
||||||
|
@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status BatchOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -192,6 +192,12 @@ class BatchOp : public ParallelOp {
|
||||||
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||||
float pad_val);
|
float pad_val);
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
|
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
|
||||||
// it is only meant to be called by PadTensor.
|
// it is only meant to be called by PadTensor.
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "dataset/engine/datasetops/device_queue_op.h"
|
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() {
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status DatasetOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// DatasetOp is the base class of visitor target.
|
||||||
|
// This method will only be called if its derived class does not implement one.
|
||||||
|
return p->RunOnNode(shared_from_this(), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,6 +32,8 @@ class ExecutionTree;
|
||||||
|
|
||||||
class DataBuffer;
|
class DataBuffer;
|
||||||
|
|
||||||
|
class NodePass;
|
||||||
|
|
||||||
// The base class DatasetOp is the main tree node. It is an abstract class, so
|
// The base class DatasetOp is the main tree node. It is an abstract class, so
|
||||||
// the actual implementation of the operators will be derived from here.
|
// the actual implementation of the operators will be derived from here.
|
||||||
class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
|
@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
// @return - the column name map as a string
|
// @return - the column name map as a string
|
||||||
std::string ColumnNameMapAsString() const;
|
std::string ColumnNameMapAsString() const;
|
||||||
|
|
||||||
|
// Children Getter
|
||||||
|
// @return Vector or Children
|
||||||
|
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
|
||||||
|
|
||||||
|
// Base method for NodePass visit.
|
||||||
|
// Subclass needs to override this if it requires special node visit access.
|
||||||
|
// Check "dataset/engine/opt/pass.h" for more details.
|
||||||
|
// @return Statue of the node visit
|
||||||
|
virtual Status Accept(NodePass *p, bool *modified);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Adds a parent operator to this operator
|
// Adds a parent operator to this operator
|
||||||
// @notes External callers do not have access to this function.
|
// @notes External callers do not have access to this function.
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "dataset/engine/dataset_iterator.h"
|
#include "dataset/engine/dataset_iterator.h"
|
||||||
#include "dataset/util/status.h"
|
#include "dataset/util/status.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
#ifdef ENABLE_TDTQUE
|
#ifdef ENABLE_TDTQUE
|
||||||
#include "tdt/tsd_client.h"
|
#include "tdt/tsd_client.h"
|
||||||
|
@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
|
||||||
out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n";
|
out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp {
|
||||||
|
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Name: checkExceptions(DataBuffer);
|
// Name: checkExceptions(DataBuffer);
|
||||||
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
|
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/kernels/tensor_op.h"
|
#include "dataset/kernels/tensor_op.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
|
||||||
}
|
}
|
||||||
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
|
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status FilterOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
|
||||||
// @param show_all A bool to control if you want to show all info or just a summary.
|
// @param show_all A bool to control if you want to show all info or just a summary.
|
||||||
void Print(std::ostream &out, bool show_all) const override;
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// predicate_func python callable which returns a boolean value.
|
// predicate_func python callable which returns a boolean value.
|
||||||
py::function predicate_func_;
|
py::function predicate_func_;
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/kernels/tensor_op.h"
|
#include "dataset/kernels/tensor_op.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
|
||||||
column_name_id_map_ = final_col_name_id_map;
|
column_name_id_map_ = final_col_name_id_map;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status MapOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -171,6 +171,12 @@ class MapOp : public ParallelOp {
|
||||||
// @return the number of threads consuming data from previous op's output Connector.
|
// @return the number of threads consuming data from previous op's output Connector.
|
||||||
int32_t num_consumers() const override;
|
int32_t num_consumers() const override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Local queues where worker threads can pop from.
|
// Local queues where worker threads can pop from.
|
||||||
// Popping directly from the Connector can block if the previous designated threads haven't pop.
|
// Popping directly from the Connector can block if the previous designated threads haven't pop.
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
|
Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status ProjectOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp {
|
||||||
// @return Status - The error code returned.
|
// @return Status - The error code returned.
|
||||||
Status EofReceived(int32_t worker_id) override;
|
Status EofReceived(int32_t worker_id) override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::string> columns_to_project_;
|
std::vector<std::string> columns_to_project_;
|
||||||
std::vector<int32_t> projected_column_indices_;
|
std::vector<int32_t> projected_column_indices_;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "dataset/core/global_context.h"
|
#include "dataset/core/global_context.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) {
|
||||||
state_ = OpState::kDeOpIdle;
|
state_ = OpState::kDeOpIdle;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status RenameOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -110,6 +110,12 @@ class RenameOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Rename core functionality
|
// Rename core functionality
|
||||||
Status RenameColumns();
|
Status RenameColumns();
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "dataset/engine/datasetops/repeat_op.h"
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const {
|
||||||
return child_[0]->num_producers();
|
return child_[0]->num_producers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp {
|
||||||
// @param workerId - The worker id
|
// @param workerId - The worker id
|
||||||
int32_t num_producers() const override;
|
int32_t num_producers() const override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t max_repeats_; // The number of repeats that the user requested
|
int32_t max_repeats_; // The number of repeats that the user requested
|
||||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||||
|
|
|
@ -30,6 +30,7 @@
|
||||||
#include "dataset/engine/dataset_iterator.h"
|
#include "dataset/engine/dataset_iterator.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
#include "dataset/util/status.h"
|
#include "dataset/util/status.h"
|
||||||
|
|
||||||
|
@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
|
||||||
state_ = OpState::kDeOpIdle;
|
state_ = OpState::kDeOpIdle;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status EoeReceived(int32_t worker_id) override;
|
Status EoeReceived(int32_t worker_id) override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Private function to add a new row to the shuffle buffer.
|
// Private function to add a new row to the shuffle buffer.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/engine/datasetops/skip_op.h"
|
#include "dataset/engine/datasetops/skip_op.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) {
|
||||||
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
|
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status SkipOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -74,6 +74,12 @@ class SkipOp : public PipelineOp {
|
||||||
// @param worker_id - The worker id
|
// @param worker_id - The worker id
|
||||||
Status EofReceived(int32_t worker_id) override;
|
Status EofReceived(int32_t worker_id) override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t max_skips_; // The number of skips that the user requested
|
int32_t max_skips_; // The number of skips that the user requested
|
||||||
int32_t skip_count_; // A counter for the current number of executed skips
|
int32_t skip_count_; // A counter for the current number of executed skips
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -250,5 +251,11 @@ Status GeneratorOp::Reset() {
|
||||||
wp_.Set();
|
wp_.Set();
|
||||||
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status Reset() override;
|
Status Reset() override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
py::function generator_function_;
|
py::function generator_function_;
|
||||||
std::vector<std::string> column_names_;
|
std::vector<std::string> column_names_;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
|
||||||
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
|
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
||||||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
|
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
|
||||||
int64_t dev_id = 0, int64_t num_dev = 1);
|
int64_t dev_id = 0, int64_t num_dev = 1);
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Initialize Sampler, calls sampler->Init() within
|
// Initialize Sampler, calls sampler->Init() within
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "dataset/engine/datasetops/dataset_op.h"
|
#include "dataset/engine/datasetops/dataset_op.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp {
|
||||||
|
|
||||||
Status SetColumnsBlob();
|
Status SetColumnsBlob();
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
|
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
#include "dataset/engine/jagged_connector.h"
|
#include "dataset/engine/jagged_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/util/path.h"
|
#include "dataset/util/path.h"
|
||||||
#include "dataset/util/queue.h"
|
#include "dataset/util/queue.h"
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
|
@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
|
||||||
|
|
||||||
return rows_read;
|
return rows_read;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp {
|
||||||
static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
|
static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
|
||||||
bool estimate = false);
|
bool estimate = false);
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The entry point for when workers are launched.
|
// The entry point for when workers are launched.
|
||||||
// @param worker_id - the id of the worker that is executing this function.
|
// @param worker_id - the id of the worker that is executing this function.
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/engine/datasetops/take_op.h"
|
#include "dataset/engine/datasetops/take_op.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() {
|
||||||
tree_->AddToRepeatStack(shared_from_this());
|
tree_->AddToRepeatStack(shared_from_this());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status TakeOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
|
||||||
// before providing their own implementations.
|
// before providing their own implementations.
|
||||||
Status PrepareNodePostAction() override;
|
Status PrepareNodePostAction() override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t max_takes_; // The number of takes that the user requested
|
int32_t max_takes_; // The number of takes that the user requested
|
||||||
int32_t take_count_; // A counter for the current number of executed takes
|
int32_t take_count_; // A counter for the current number of executed takes
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "dataset/core/constants.h"
|
#include "dataset/core/constants.h"
|
||||||
#include "dataset/engine/data_buffer.h"
|
#include "dataset/engine/data_buffer.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/core/config_manager.h"
|
#include "dataset/core/config_manager.h"
|
||||||
#include "dataset/core/global_context.h"
|
#include "dataset/core/global_context.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) {
|
||||||
state_ = OpState::kDeOpIdle;
|
state_ = OpState::kDeOpIdle;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status ZipOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -104,6 +104,12 @@ class ZipOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Handles preprocessing of the main loop, used when starting new epoch
|
// Handles preprocessing of the main loop, used when starting new epoch
|
||||||
Status prepare(TensorQTable *const table);
|
Status prepare(TensorQTable *const table);
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
|
||||||
|
#include "dataset/engine/opt/util/printer_pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
// Constructor
|
// Constructor
|
||||||
|
@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The driver of the prepare phase of the execution tree.
|
||||||
|
// Prepare phase consists of three sub phases
|
||||||
|
//
|
||||||
|
// 1. PrepareTreePreAction()
|
||||||
|
// Compulsory transformation/action pre optimization.
|
||||||
|
// For example, CacheOp Insertion
|
||||||
|
//
|
||||||
|
// 2. Optimize()
|
||||||
|
// Optimization transformation/action, optional
|
||||||
|
// For example, MapOp Fusion
|
||||||
|
//
|
||||||
|
// 3. PrepareTreePostAction()
|
||||||
|
// Compulsory transformation/action post optimization.
|
||||||
|
// For example, repeatOp inlining
|
||||||
|
//
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status ExecutionTree::Prepare() {
|
||||||
|
// Pre optimization compulsory transformation
|
||||||
|
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
|
||||||
|
|
||||||
|
// Optimization transformation
|
||||||
|
RETURN_IF_NOT_OK(this->Optimize());
|
||||||
|
|
||||||
|
// Post optimization compulsory transformation
|
||||||
|
RETURN_IF_NOT_OK(this->PrepareTreePostAction());
|
||||||
|
|
||||||
|
// Existing transformation implementation, will be removed later
|
||||||
|
RETURN_IF_NOT_OK(this->PrepareDeprecated());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
|
||||||
|
|
||||||
|
Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
|
||||||
|
|
||||||
|
Status ExecutionTree::Optimize() {
|
||||||
|
// auto pp = new PrinterPass();
|
||||||
|
// bool modified = false;
|
||||||
|
// pp->Run(this, &modified);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
|
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||||
// it ready for execution.
|
// it ready for execution.
|
||||||
Status ExecutionTree::Prepare() {
|
//
|
||||||
|
// This driver is deprecated.
|
||||||
|
Status ExecutionTree::PrepareDeprecated() {
|
||||||
// Tree must be in pending prepare state before we can assign root to it
|
// Tree must be in pending prepare state before we can assign root to it
|
||||||
if (tree_state_ != kDeTStatePrepare) {
|
if (tree_state_ != kDeTStatePrepare) {
|
||||||
std::string err_msg =
|
std::string err_msg =
|
||||||
|
|
|
@ -152,11 +152,41 @@ class ExecutionTree {
|
||||||
// @return the prepare flags
|
// @return the prepare flags
|
||||||
uint32_t PrepareFlags() const { return prepare_flags_; }
|
uint32_t PrepareFlags() const { return prepare_flags_; }
|
||||||
|
|
||||||
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
|
// The driver of the prepare phase of the execution tree.
|
||||||
|
// Prepare phase consists of three sub phases
|
||||||
|
//
|
||||||
|
// 1. PrepareTreePreAction()
|
||||||
|
// Compulsory transformation/action pre optimization.
|
||||||
|
// For example, CacheOp Insertion
|
||||||
|
//
|
||||||
|
// 2. Optimize()
|
||||||
|
// Optimization transformation/action, optional
|
||||||
|
// For example, MapOp Fusion
|
||||||
|
//
|
||||||
|
// 3. PrepareTreePostAction()
|
||||||
|
// Compulsory transformation/action post optimization.
|
||||||
|
// For example, repeatOp inlining
|
||||||
|
//
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status Prepare();
|
||||||
|
|
||||||
|
// Compulsory transformation/action pre optimization.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status PrepareTreePreAction();
|
||||||
|
|
||||||
|
// Compulsory transformation/action post optimization.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status PrepareTreePostAction();
|
||||||
|
|
||||||
|
// Optimization transformation/action, optional.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status Optimize();
|
||||||
|
|
||||||
|
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||||
// it ready for execution.
|
// it ready for execution.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status Prepare();
|
Status PrepareDeprecated();
|
||||||
|
|
||||||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||||
// node actions during a tree walk.
|
// node actions during a tree walk.
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
|
add_library(engine-opt OBJECT
|
||||||
|
pass.cc
|
||||||
|
util/printer_pass.cc
|
||||||
|
)
|
|
@ -0,0 +1,160 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
#include "dataset/engine/datasetops/dataset_op.h"
|
||||||
|
#include "dataset/engine/datasetops/batch_op.h"
|
||||||
|
#include "dataset/engine/datasetops/dataset_op.h"
|
||||||
|
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||||
|
#include "dataset/engine/datasetops/map_op.h"
|
||||||
|
#include "dataset/engine/datasetops/project_op.h"
|
||||||
|
#include "dataset/engine/datasetops/rename_op.h"
|
||||||
|
#include "dataset/engine/datasetops/filter_op.h"
|
||||||
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
|
#include "dataset/engine/datasetops/skip_op.h"
|
||||||
|
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||||
|
#include "dataset/engine/datasetops/take_op.h"
|
||||||
|
#include "dataset/engine/datasetops/zip_op.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Driver method for TreePass
|
||||||
|
Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }
|
||||||
|
|
||||||
|
// Driver method for NodePass
|
||||||
|
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
|
||||||
|
std::shared_ptr<DatasetOp> root = tree->root();
|
||||||
|
if (traversalOrder_ == Order::DFS) {
|
||||||
|
// DFS
|
||||||
|
return DFSNodeVisit(root, modified);
|
||||||
|
} else if (traversalOrder_ == Order::BFS) {
|
||||||
|
// BFS
|
||||||
|
return BFSNodeVisit(root, modified);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to perform DFS visit
|
||||||
|
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||||
|
for (const auto &c : node->Children()) {
|
||||||
|
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
|
||||||
|
}
|
||||||
|
return node->Accept(this, modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to perform BFS visit
|
||||||
|
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) {
|
||||||
|
// Initialize bfs queue with root
|
||||||
|
std::queue<std::shared_ptr<DatasetOp>> bfsQueue;
|
||||||
|
bfsQueue.push(root);
|
||||||
|
|
||||||
|
// BFS loop
|
||||||
|
while (!bfsQueue.empty()) {
|
||||||
|
// Pop the front of the bfs queue
|
||||||
|
auto curNode = bfsQueue.front();
|
||||||
|
bfsQueue.pop();
|
||||||
|
|
||||||
|
// Run node pass
|
||||||
|
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
|
||||||
|
|
||||||
|
// Push children into bfs queue
|
||||||
|
for (const auto &c : curNode->Children()) {
|
||||||
|
bfsQueue.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||||
|
// Fallback to base class visitor by default
|
||||||
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,146 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class BatchOp;
|
||||||
|
|
||||||
|
class MapOp;
|
||||||
|
|
||||||
|
class ProjectOp;
|
||||||
|
|
||||||
|
class RenameOp;
|
||||||
|
|
||||||
|
class FilterOp;
|
||||||
|
|
||||||
|
class SkipOp;
|
||||||
|
|
||||||
|
class ShuffleOp;
|
||||||
|
|
||||||
|
class GeneratorOp;
|
||||||
|
|
||||||
|
class MindRecordOp;
|
||||||
|
|
||||||
|
class TFReaderOp;
|
||||||
|
|
||||||
|
class TakeOp;
|
||||||
|
|
||||||
|
class ZipOp;
|
||||||
|
|
||||||
|
class DeviceQueueOp;
|
||||||
|
|
||||||
|
class ImageFolderOp;
|
||||||
|
|
||||||
|
// The base class Pass is the basic unit of tree transformation.
|
||||||
|
// The actual implementation of the passes will be derived from here.
|
||||||
|
class Pass : public std::enable_shared_from_this<Pass> {
|
||||||
|
public:
|
||||||
|
// Run the transformation pass again the execution tree.
|
||||||
|
// @param tree - Pointer to the execution tree to be transformed.
|
||||||
|
// @param modified - Pointer to the modified flag,
|
||||||
|
virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
|
||||||
|
class TreePass : public Pass {
|
||||||
|
public:
|
||||||
|
// Run the transformation pass against the execution tree.
|
||||||
|
// @param tree - Pointer to the execution tree to be transformed.
|
||||||
|
// @param modified - Pointer to the modified flag,
|
||||||
|
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||||
|
|
||||||
|
// Derived classes may implement the runOnTree function to implement tree transformation.
|
||||||
|
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// NodePass is a basic Pass class which performs transformation on Node visiting.
|
||||||
|
// NodePass implements Visitor design pattern.
|
||||||
|
class NodePass : public Pass {
|
||||||
|
public:
|
||||||
|
// Tree traversal order
|
||||||
|
enum Order { DFS, BFS };
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
// Default DFS traversal
|
||||||
|
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
|
||||||
|
|
||||||
|
// Run the transformation pass against the execution tree.
|
||||||
|
// @param tree - Pointer to the execution tree to be transformed.
|
||||||
|
// @param modified - Pointer to the modified flag,
|
||||||
|
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||||
|
|
||||||
|
// Derived classes may implement the runOnNode function to implement node level tree transformation.
|
||||||
|
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
|
||||||
|
|
||||||
|
// Visit methods to be overridden.
|
||||||
|
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
|
||||||
|
// of its own type and override "Accept" from DatasetOp.
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Helper function to perform DFS visit
|
||||||
|
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||||
|
|
||||||
|
// Helper function to perform BFS visit
|
||||||
|
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
|
||||||
|
|
||||||
|
// Tree traversal order of the NodePass
|
||||||
|
Order traversalOrder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_H_
|
|
@ -0,0 +1,111 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/util/printer_pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting DatasetOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting BatchOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting MapOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting ProjectOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting RenameOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting FilterOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting SkipOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting ShuffleOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting GeneratorOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting MindRecordOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting TFReaderOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting TakeOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting ZipOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting DeviceQueueOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
std::cout << "Visiting ImageFolderOp" << '\n';
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class PrinterPass : public NodePass {
|
||||||
|
public:
|
||||||
|
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
|
||||||
|
# Generate 1d int numpy array from 0 - 63
|
||||||
|
def generator_1d():
|
||||||
|
for i in range(64):
|
||||||
|
yield (np.array([i]),)
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_0():
|
||||||
|
"""
|
||||||
|
Test 1D Generator
|
||||||
|
"""
|
||||||
|
|
||||||
|
# apply dataset operations
|
||||||
|
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||||
|
|
||||||
|
data1 = data1.shuffle(2)
|
||||||
|
|
||||||
|
data1 = data1.map(["data"], operations=(lambda x : x))
|
||||||
|
|
||||||
|
data1 = data1.batch(2)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_case_0()
|
Loading…
Reference in New Issue