Migrate 3 pre passes to IR optimizer, namely, cache_error_pass, epoch_injection, and removal_pass

Nat Sutyanyong 2020-11-27 16:45:00 -05:00
parent 73c91e05b1
commit d69a29a44e
61 changed files with 1095 additions and 359 deletions

View File

@ -574,7 +574,7 @@ Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize))); std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
return pre; return pre;
}); });
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1)); RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
TensorRow row; TensorRow row;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row)); RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
int64_t row_cnt = 0; int64_t row_cnt = 0;

View File

@ -214,7 +214,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// The driver of the prepare phase of the execution tree. // The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases // Prepare phase consists of three sub phases
// //
// 1. PrepareTreePreAction() // 1. PreAction()
// Compulsory transformation/action pre optimization. // Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion // For example, CacheOp Insertion
// //
@ -222,41 +222,44 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// Optimization transformation/action, optional // Optimization transformation/action, optional
// For example, MapOp Fusion // For example, MapOp Fusion
// //
// 3. PrepareTreePostAction() // 3. PostAction()
// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return // @return Status - The error code return
Status ExecutionTree::Prepare(int32_t num_epochs) { Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs; num_epochs_ = num_epochs;
partially_prepare_ = partial;
// Pre optimization compulsory transformation // Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction()); RETURN_IF_NOT_OK(this->PreAction());
// If optional optimizations are enabled // If optional optimizations are enabled
if (optimize_) { if (optimize_) {
RETURN_IF_NOT_OK(this->Optimize()); RETURN_IF_NOT_OK(this->Optimize());
} }
// Post optimization compulsory transformation // Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePostAction()); RETURN_IF_NOT_OK(this->PostAction());
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
// Existing transformation implementation, will be removed later // Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated()); RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK(); return Status::OK();
} }
Status ExecutionTree::PrepareTreePreAction() { Status ExecutionTree::PreAction() {
bool modified = false; bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions; std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions // Construct pre actions
if (!partially_prepare_) {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheErrorPass>()); pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif #endif
pre_actions.push_back(std::make_unique<EpochInjectionPass>()); pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>()); pre_actions.push_back(std::make_unique<RemovalPass>());
#ifndef ENABLE_ANDROID }
pre_actions.push_back(std::make_unique<CacheTransformPass>());
#endif
// this offers a way to override the preset optimization pass with customized ones // this offers a way to override the preset optimization pass with customized ones
// this is used when certain nodes are removed for tree getters // this is used when certain nodes are removed for tree getters
@ -276,15 +279,17 @@ Status ExecutionTree::PrepareTreePreAction() {
return Status::OK(); return Status::OK();
} }
Status ExecutionTree::PrepareTreePostAction() { Status ExecutionTree::PostAction() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
bool modified = false; bool modified = false;
OptPass post_actions; OptPass post_actions;
// Construct pre actions // Construct pre actions
MS_LOG(INFO) << "Running post pass loops."; MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<CacheTransformPass>());
post_actions.push_back(std::make_unique<RepeatPass>()); post_actions.push_back(std::make_unique<RepeatPass>());
#endif #endif
@ -340,9 +345,6 @@ Status ExecutionTree::PrepareDeprecated() {
// Recursive function used during prepare phase to visit a node and drive any pre- and post- // Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk. // node actions during a tree walk.
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) { Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
// execute PreAction
RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
// Before going down into children, make any prepare flags updates based on this operator. // Before going down into children, make any prepare flags updates based on this operator.
uint32_t op_prep_flags = dataset_op->PrepareFlags(); uint32_t op_prep_flags = dataset_op->PrepareFlags();
BitSet(&prepare_flags_, op_prep_flags); BitSet(&prepare_flags_, op_prep_flags);
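For orientation, the renamed Prepare()/PreAction()/PostAction() sequence and the new partial flag introduced above can be summarized with a minimal standalone sketch. The class below is illustrative only and is not the real ExecutionTree; it only models the branching, not the actual passes.

// Minimal sketch of the prepare-phase flow after this change; illustrative only.
#include <cstdint>
#include <iostream>

class TreeSketch {
 public:
  bool Prepare(int32_t num_epochs, bool partial = false) {
    num_epochs_ = num_epochs;
    partially_prepare_ = partial;
    if (!PreAction()) return false;              // compulsory pre-optimization passes
    if (optimize_ && !Optimize()) return false;  // optional optimizations
    return PostAction();                         // compulsory post-optimization passes
  }

 private:
  bool PreAction() {
    // With partial == true, the IR optimizer already ran CacheErrorPass,
    // EpochInjectionPass and RemovalPass, so their tree-level copies are skipped.
    if (!partially_prepare_) std::cout << "run CacheError/EpochInjection/Removal\n";
    return true;
  }
  bool Optimize() { return true; }
  bool PostAction() {
    // CacheErrorPass (re-run as a temporary fix), CacheTransformPass and
    // RepeatPass still run here on the execution tree.
    std::cout << "run CacheError/CacheTransform/Repeat\n";
    return true;
  }
  int32_t num_epochs_ = -1;
  bool optimize_ = true;
  bool partially_prepare_ = false;
};

int main() { TreeSketch().Prepare(1, /*partial=*/true); }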

View File

@ -169,7 +169,7 @@ class ExecutionTree {
// The driver of the prepare phase of the execution tree. // The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases // Prepare phase consists of three sub phases
// //
// 1. PrepareTreePreAction() // 1. PreAction()
// Compulsory transformation/action pre optimization. // Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion // For example, CacheOp Insertion
// //
@ -177,20 +177,20 @@ class ExecutionTree {
// Optimization transformation/action, optional // Optimization transformation/action, optional
// For example, MapOp Fusion // For example, MapOp Fusion
// //
// 3. PrepareTreePostAction() // 3. PostAction()
// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return // @return Status - The error code return
Status Prepare(int num_epochs = -1); Status Prepare(int num_epochs = -1, bool partial = false);
// Compulsory transformation/action pre optimization. // Compulsory transformation/action pre optimization.
// @return Status - The error code return // @return Status - The error code return
Status PrepareTreePreAction(); Status PreAction();
// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// @return Status - The error code return // @return Status - The error code return
Status PrepareTreePostAction(); Status PostAction();
// Optimization transformation/action, optional. // Optimization transformation/action, optional.
// @return Status - The error code return // @return Status - The error code return
@ -281,6 +281,7 @@ class ExecutionTree {
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations bool optimize_; // Flag to enable optional optimizations
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -139,5 +140,16 @@ Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for IRNodePass
Status BatchNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BatchNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status BatchNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BatchNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
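The Accept/AcceptAfter methods added to BatchNode (and to the other IR nodes further down) follow the usual double-dispatch visitor idiom: the node downcasts itself and hands the statically typed pointer to the pass. A self-contained sketch of that idiom follows, with placeholder class names rather than the MindSpore types.

// Standalone illustration of the Accept/Visit double dispatch; placeholder types only.
#include <iostream>
#include <memory>

class BatchLike;  // forward declaration so the pass can name the concrete node type

class PassLike {
 public:
  virtual ~PassLike() = default;
  // Overload selected by the node's static type once Accept() has downcast it.
  virtual void Visit(std::shared_ptr<BatchLike> node, bool *modified) = 0;
};

class NodeLike : public std::enable_shared_from_this<NodeLike> {
 public:
  virtual ~NodeLike() = default;
  virtual void Accept(PassLike *p, bool *modified) = 0;

 protected:
  template <typename Derived>
  std::shared_ptr<Derived> shared_from_base() {
    return std::static_pointer_cast<Derived>(shared_from_this());
  }
};

class BatchLike : public NodeLike {
 public:
  void Accept(PassLike *p, bool *modified) override {
    // Downcast shared pointer, then hand the statically typed node to the pass.
    p->Visit(shared_from_base<BatchLike>(), modified);
  }
};

class PrintPass : public PassLike {
 public:
  void Visit(std::shared_ptr<BatchLike> /*node*/, bool *modified) override {
    std::cout << "visiting a batch-like node\n";
    *modified = false;
  }
};

int main() {
  std::shared_ptr<NodeLike> node = std::make_shared<BatchLike>();
  PrintPass pass;
  bool modified = false;
  node->Accept(&pass, &modified);
}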

View File

@ -74,6 +74,18 @@ class BatchNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
int32_t batch_size_; int32_t batch_size_;
bool drop_remainder_; bool drop_remainder_;

View File

@ -46,12 +46,40 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() { std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_); element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_);
return node; return node;
} }
void BucketBatchByLengthNode::Print(std::ostream &out) const { void BucketBatchByLengthNode::Print(std::ostream &out) const {
out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; out << Name() + "(columns:" + PrintColumns(column_names_);
int i = 0;
for (auto it : bucket_boundaries_) {
if (i == 0) {
out << ",bucket_boundaries:{";
}
out << it;
if (i < bucket_boundaries_.size() - 1) {
out << ",";
} else {
out << "}";
}
i++;
}
i = 0;
for (auto it : bucket_batch_sizes_) {
if (i == 0) {
out << ",bucket_batch_sizes:{";
}
out << it;
if (i < bucket_batch_sizes_.size() - 1) {
out << ",";
} else {
out << "}";
}
i++;
}
out << ")";
} }
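For reference, the two loops above produce output of the form ",bucket_boundaries:{10,20,30},bucket_batch_sizes:{32,16,8,4}". The small standalone helper below produces the same brace-and-comma layout; it is only an illustration of the output format and is not part of the commit (note the real Print omits a field entirely when its vector is empty).

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Joins values as "{v1,v2,...,vn}"; unlike the real Print, an empty vector yields "{}".
std::string BraceJoin(const std::vector<int32_t> &values) {
  std::ostringstream out;
  out << "{";
  for (std::size_t i = 0; i < values.size(); ++i) {
    if (i > 0) out << ",";
    out << values[i];
  }
  out << "}";
  return out.str();
}

int main() {
  std::vector<int32_t> bucket_boundaries = {10, 20, 30};
  std::vector<int32_t> bucket_batch_sizes = {32, 16, 8, 4};
  std::cout << "bucket_boundaries:" << BraceJoin(bucket_boundaries)
            << ",bucket_batch_sizes:" << BraceJoin(bucket_batch_sizes) << "\n";
}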
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

View File

@ -90,14 +90,14 @@ Status BuildSentenceVocabNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { Status BuildSentenceVocabNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified); return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { Status BuildSentenceVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified); return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
} }

View File

@ -59,17 +59,17 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::shared_ptr<SentencePieceVocab> vocab_; std::shared_ptr<SentencePieceVocab> vocab_;

View File

@ -85,14 +85,14 @@ Status BuildVocabNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status BuildVocabNode::Accept(NodePass *p, bool *modified) { Status BuildVocabNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildVocabNode>(), modified); return p->Visit(shared_from_base<BuildVocabNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { Status BuildVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified); return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
} }

View File

@ -58,17 +58,17 @@ class BuildVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::shared_ptr<Vocab> vocab_; std::shared_ptr<Vocab> vocab_;

View File

@ -39,8 +39,10 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
} }
std::shared_ptr<DatasetNode> ConcatNode::Copy() { std::shared_ptr<DatasetNode> ConcatNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
// create an empty vector to copy a concat // create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>()); auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
children_flag_and_nums_, children_start_end_index_);
return node; return node;
} }
@ -80,14 +82,14 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status ConcatNode::Accept(NodePass *p, bool *modified) { Status ConcatNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<ConcatNode>(), modified); return p->Visit(shared_from_base<ConcatNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { Status ConcatNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ConcatNode>(), modified); return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
} }

View File

@ -66,17 +66,17 @@ class ConcatNode : public DatasetNode {
std::vector<std::pair<int, int>> children_flag_and_nums_; std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_; std::vector<std::pair<int, int>> children_start_end_index_;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset

View File

@ -242,9 +242,27 @@ DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) {
worker_connector_size_ = cfg->worker_connector_size(); worker_connector_size_ = cfg->worker_connector_size();
} }
const bool DatasetNode::IsTree() const {
bool is_tree = true;
if (this->parent_.size() > 1) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
return false;
}
for (const auto &child : children_) {
is_tree = child->IsTree();
if (!is_tree) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
break;
}
}
return is_tree;
}
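IsTree() enforces the invariant that no node in the IR graph has more than one parent. The standalone model below demonstrates that check with placeholder types rather than DatasetNode.

// Standalone model of the IsTree() invariant; placeholder types only.
#include <iostream>
#include <memory>
#include <vector>

struct Node {
  std::vector<Node *> parents;
  std::vector<std::shared_ptr<Node>> children;

  void AddChild(const std::shared_ptr<Node> &child) {
    children.push_back(child);
    child->parents.push_back(this);
  }

  bool IsTree() const {
    if (parents.size() > 1) return false;  // a tree node has at most one parent
    for (const auto &child : children) {
      if (!child->IsTree()) return false;
    }
    return true;
  }
};

int main() {
  auto root = std::make_shared<Node>();
  auto a = std::make_shared<Node>();
  auto shared = std::make_shared<Node>();
  root->AddChild(a);
  root->AddChild(shared);
  a->AddChild(shared);  // "shared" now has two parents, so this is a DAG, not a tree
  std::cout << std::boolalpha << root->IsTree() << "\n";  // prints false
}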
// this function will perform a deep copy of current node (and its descendants), the parent* pointer will not be copied // this function will perform a deep copy of current node (and its descendants), the parent* pointer will not be copied
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() { std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy(); std::shared_ptr<DatasetNode> new_node = this->Copy();
// temporary fix to set the num_workers to the new node.
new_node->SetNumWorkers(this->num_workers_);
for (const auto &child : children_) { for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy()); new_node->AddChild(child->DeepCopy());
} }
@ -298,12 +316,31 @@ void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
children_.push_back(child); children_.push_back(child);
child->parent_.push_back(this); child->parent_.push_back(this);
} else if (child != nullptr) { } else if (child != nullptr) {
MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr."; MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
children_.push_back(child); children_.push_back(child);
child->parent_.push_back(this); child->parent_.push_back(this);
} }
} }
// Insert a node as a child of this node. This node's children become the children of the inserted node.
Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) {
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children.");
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent.");
for (auto child : children_) {
node->children_.push_back(child);
child->parent_.clear();
child->parent_.push_back(node.get());
}
// Then establish the new parent-child relationship with the new parent.
children_.clear();
children_.push_back(node);
node->parent_.clear();
node->parent_.push_back(this);
return Status::OK();
}
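InsertBelow() is the kind of primitive the migrated IR pre passes rely on. As a purely hypothetical illustration (this is not the actual EpochInjectionPass code), an IR-level epoch injection could place an EpochCtrlNode under the root using this method together with the single-argument EpochCtrlNode constructor added later in this commit; the header paths and namespace below are assumptions.

// Hypothetical illustration only; not the real EpochInjectionPass.
#include <memory>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"  // assumed path

namespace mindspore {
namespace dataset {

Status InjectEpochCtrl(const std::shared_ptr<DatasetNode> &root, int32_t num_epochs) {
  auto epoch_ctrl = std::make_shared<EpochCtrlNode>(num_epochs);
  // root's current children become children of epoch_ctrl, and epoch_ctrl
  // becomes root's only child (see InsertBelow above).
  return root->InsertBelow(epoch_ctrl);
}

}  // namespace dataset
}  // namespace mindspore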
// Remove this node from its parent. Add the child of this node to its parent. // Remove this node from its parent. Add the child of this node to its parent.
// for now, this remove is limited to node with a single child or no child // for now, this remove is limited to node with a single child or no child
Status DatasetNode::Remove() { Status DatasetNode::Remove() {
@ -325,14 +362,14 @@ Status DatasetNode::Remove() {
} }
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit. // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Status DatasetNode::Accept(NodePass *p, bool *modified) { Status DatasetNode::Accept(IRNodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one. // This method will only be called if its derived class does not implement one.
return p->Visit(shared_from_this(), modified); return p->Visit(shared_from_this(), modified);
} }
// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
// after all child nodes are visited. // after all child nodes are visited.
Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { Status DatasetNode::AcceptAfter(IRNodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one. // This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified); return p->VisitAfter(shared_from_this(), modified);
} }
@ -369,17 +406,5 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
} }
} }
// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<SourceNode>(), modified);
}
// Visitor accepting method for NodePass
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<SourceNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -32,7 +32,7 @@ namespace dataset {
class Dataset; class Dataset;
class SamplerObj; class SamplerObj;
class NodePass; class IRNodePass;
class DatasetSizeGetter; class DatasetSizeGetter;
// Names for non-leaf IR node // Names for non-leaf IR node
@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Establish the parent-child relationship between this node and its child. /// \brief Establish the parent-child relationship between this node and its child.
void AddChild(std::shared_ptr<DatasetNode> child); void AddChild(std::shared_ptr<DatasetNode> child);
/// \brief Insert the input node below this node. This node's children become the children of the inserted node.
Status InsertBelow(std::shared_ptr<DatasetNode> node);
/// \brief detach this node from its parent, add its child (if any) to its parent /// \brief detach this node from its parent, add its child (if any) to its parent
/// \return error code, return error if node has more than 1 children /// \return error code, return error if node has more than 1 children
Status Remove(); Status Remove();
@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return True if the data of this node will be cached /// \return True if the data of this node will be cached
const bool IsCached() const { return (cache_ != nullptr); } const bool IsCached() const { return (cache_ != nullptr); }
/// \brief Check if this node is a tree
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
const bool IsTree() const;
/// \brief Check if this node is a leaf node.
/// \return True if this is a leaf node.
const bool IsLeaf() const { return children_.empty(); }
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
/// \return True if a cache-enabled operator is an ancestor of this node
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
void HasCacheAbove() { descendant_of_cache_ = true; }
/// \brief Setter function for runtime number of workers /// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator /// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object /// \return Shared pointer to the original object
@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
return std::static_pointer_cast<Derived>(shared_from_this()); return std::static_pointer_cast<Derived>(shared_from_this());
} }
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up /// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited. /// visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access. /// \notes Subclass needs to override this if it requires special node visit access.
@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
virtual Status Accept(NodePass *p, bool *modified); virtual Status Accept(IRNodePass *p, bool *modified);
/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited. /// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access. /// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details. /// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
virtual Status AcceptAfter(NodePass *p, bool *modified); virtual Status AcceptAfter(IRNodePass *p, bool *modified);
virtual bool IsSizeDefined() { return true; } virtual bool IsSizeDefined() { return true; }
@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
std::string PrintColumns(const std::vector<std::string> &columns) const; std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const; void PrintNode(std::ostream &out, int *level) const;
};
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
class SourceNode : public DatasetNode {
public:
/// \brief Constructor
SourceNode() : DatasetNode() {}
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
/// \brief Destructor
~SourceNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
protected:
bool mappable_; bool mappable_;
bool descendant_of_cache_;
}; };
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
class MappableSourceNode : public SourceNode { class MappableSourceNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
MappableSourceNode() : SourceNode() { mappable_ = true; } MappableSourceNode() : DatasetNode() { mappable_ = true; }
/// \brief Constructor that initializes the cache /// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache /// \param dataset_cache DatasetCache
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = true; mappable_ = true;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
} }
/// \brief Destructor /// \brief Destructor
@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode {
}; };
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
class NonMappableSourceNode : public SourceNode { class NonMappableSourceNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
NonMappableSourceNode() : SourceNode() { mappable_ = false; } NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
/// \brief Constructor that initializes the cache /// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache /// \param dataset_cache DatasetCache
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = false; mappable_ = false;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
} }
/// \brief Destructor /// \brief Destructor
@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode {
/// \return Name of the current node /// \return Name of the current node
virtual std::string Name() const = 0; virtual std::string Name() const = 0;
}; };
// NonLeafNode represents operations over data in a pipeline.
class NonLeafNode : public DatasetNode {
public:
/// \brief Constructor
NonLeafNode() = default;
/// \brief Destructor
~NonLeafNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// SinkNode represents the end node of a pipeline where the data is pushed out
class SinkNode : public DatasetNode {
public:
/// \brief Constructor
SinkNode() = default;
/// \brief Destructor
~SinkNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
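With SourceNode removed, mappable_ and descendant_of_cache_ now live on DatasetNode itself, and the Build() methods later in this diff consult IsDescendantOfCache() to avoid injecting a ShuffleOp under a cache. The standalone model below sketches that bookkeeping with placeholder types; the real marking is done by the cache passes, so this is only an illustration.

// Standalone model of the descendant_of_cache_ bookkeeping; placeholder types only.
#include <iostream>
#include <memory>
#include <vector>

struct IrNode {
  bool cached = false;               // a DatasetCache is attached to this node
  bool descendant_of_cache = false;  // set on leaves that sit under a cached node
  std::vector<std::shared_ptr<IrNode>> children;

  bool IsLeaf() const { return children.empty(); }
};

// Walk down, remembering whether any ancestor is cached, and mark affected leaves.
void MarkLeavesUnderCache(const std::shared_ptr<IrNode> &node, bool cache_above = false) {
  cache_above = cache_above || node->cached;
  if (node->IsLeaf() && cache_above) node->descendant_of_cache = true;
  for (const auto &child : node->children) MarkLeavesUnderCache(child, cache_above);
}

int main() {
  auto leaf = std::make_shared<IrNode>();
  auto map = std::make_shared<IrNode>();
  map->cached = true;  // cache over the map node
  map->children.push_back(leaf);
  MarkLeavesUnderCache(map);
  std::cout << std::boolalpha << leaf->descendant_of_cache << "\n";  // prints true
}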

View File

@ -32,8 +32,9 @@ EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epo
// The root node's parent must set to null pointer. // The root node's parent must set to null pointer.
this->AddChild(child); this->AddChild(child);
} }
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_); auto node = std::make_shared<EpochCtrlNode>(num_epochs_);
return node; return node;
} }

View File

@ -29,7 +29,10 @@ namespace dataset {
class EpochCtrlNode : public DatasetNode { class EpochCtrlNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {}
/// \brief Constructor
EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
/// \brief Destructor /// \brief Destructor
~EpochCtrlNode() = default; ~EpochCtrlNode() = default;

View File

@ -60,14 +60,14 @@ Status FilterNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status FilterNode::Accept(NodePass *p, bool *modified) { Status FilterNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<FilterNode>(), modified); return p->Visit(shared_from_base<FilterNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { Status FilterNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<FilterNode>(), modified); return p->VisitAfter(shared_from_base<FilterNode>(), modified);
} }

View File

@ -58,17 +58,17 @@ class FilterNode : public DatasetNode {
bool IsSizeDefined() override { return false; }; bool IsSizeDefined() override { return false; };
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::shared_ptr<TensorOp> predicate_; std::shared_ptr<TensorOp> predicate_;

View File

@ -42,14 +42,16 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
} }
std::shared_ptr<DatasetNode> MapNode::Copy() { std::shared_ptr<DatasetNode> MapNode::Copy() {
auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, std::vector<std::shared_ptr<TensorOperation>> operations = operations_;
auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_,
callbacks_); callbacks_);
return node; return node;
} }
void MapNode::Print(std::ostream &out) const { void MapNode::Print(std::ostream &out) const {
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
",<project_cols>" + ",...)"; ",<project_cols>" + ",num_tensor_ops:"
<< operations_.size() << ",...)";
} }
Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
@ -101,14 +103,14 @@ Status MapNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status MapNode::Accept(NodePass *p, bool *modified) { Status MapNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<MapNode>(), modified); return p->Visit(shared_from_base<MapNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status MapNode::AcceptAfter(NodePass *p, bool *modified) { Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<MapNode>(), modified); return p->VisitAfter(shared_from_base<MapNode>(), modified);
} }

View File

@ -63,17 +63,17 @@ class MapNode : public DatasetNode {
const auto &TensorOperations() const { return operations_; } const auto &TensorOperations() const { return operations_; }
auto &TensorOperations() { return operations_; } auto &TensorOperations() { return operations_; }
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::vector<std::shared_ptr<TensorOperation>> operations_; std::vector<std::shared_ptr<TensorOperation>> operations_;

View File

@ -70,14 +70,14 @@ Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) { Status RepeatNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<RepeatNode>(), modified); return p->Visit(shared_from_base<RepeatNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { Status RepeatNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<RepeatNode>(), modified); return p->VisitAfter(shared_from_base<RepeatNode>(), modified);
} }

View File

@ -66,17 +66,17 @@ class RepeatNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
int32_t repeat_count_; int32_t repeat_count_;

View File

@ -72,14 +72,14 @@ Status RootNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status RootNode::Accept(NodePass *p, bool *modified) { Status RootNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<RootNode>(), modified); return p->Visit(shared_from_base<RootNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status RootNode::AcceptAfter(NodePass *p, bool *modified) { Status RootNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<RootNode>(), modified); return p->VisitAfter(shared_from_base<RootNode>(), modified);
} }

View File

@ -58,17 +58,17 @@ class RootNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
int32_t num_epochs_; int32_t num_epochs_;

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "minddata/dataset/engine/datasetops/skip_op.h" #include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
@ -70,5 +71,16 @@ Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for IRNodePass
Status SkipNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<SkipNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status SkipNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<SkipNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -64,6 +64,18 @@ class SkipNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
int32_t skip_count_; int32_t skip_count_;
}; };

View File

@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
sampler_(sampler) {} sampler_(sampler) {}
std::shared_ptr<DatasetNode> AlbumNode::Copy() { std::shared_ptr<DatasetNode> AlbumNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
return node; return node;
} }

View File

@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
extensions_(extensions) {} extensions_(extensions) {}
std::shared_ptr<DatasetNode> CelebANode::Copy() { std::shared_ptr<DatasetNode> CelebANode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
return node; return node;
} }

View File

@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> Cifar100Node::Copy() { std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }

View File

@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> Cifar10Node::Copy() { std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }

View File

@ -208,7 +208,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
RETURN_IF_NOT_OK(clue_op->Init()); RETURN_IF_NOT_OK(clue_op->Init());
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
// Inject ShuffleOp // Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr; std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0; int64_t num_rows = 0;

View File

@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation
sampler_(sampler) {} sampler_(sampler) {}
std::shared_ptr<DatasetNode> CocoNode::Copy() { std::shared_ptr<DatasetNode> CocoNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
return node; return node;
} }

View File

@ -119,7 +119,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
RETURN_IF_NOT_OK(csv_op->Init()); RETURN_IF_NOT_OK(csv_op->Init());
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
// Inject ShuffleOp // Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr; std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0; int64_t num_rows = 0;

View File

@ -33,8 +33,16 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
column_names_(column_names), column_names_(column_names),
column_types_(column_types) {} column_types_(column_types) {}
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {}
std::shared_ptr<DatasetNode> GeneratorNode::Copy() { std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); std::shared_ptr<GeneratorNode> node;
if (schema_ == nullptr) {
node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
} else {
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
}
return node; return node;
} }
@ -42,9 +50,6 @@ void GeneratorNode::Print(std::ostream &out) const {
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)"; out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)";
} }
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {}
Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();

View File

@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
exts_(extensions) {} exts_(extensions) {}
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() { std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = auto node =
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
return node; return node;

View File

@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
sampler_(sampler) {} sampler_(sampler) {}
std::shared_ptr<DatasetNode> ManifestNode::Copy() { std::shared_ptr<DatasetNode> ManifestNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_); auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
return node; return node;
} }

View File

@ -54,12 +54,13 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
std::shared_ptr<DatasetNode> MindDataNode::Copy() { std::shared_ptr<DatasetNode> MindDataNode::Copy() {
std::shared_ptr<MindDataNode> node; std::shared_ptr<MindDataNode> node;
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
if (dataset_files_.empty()) { if (dataset_files_.empty()) {
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
} else { } else {
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
} }
node->SetSampleBytes(&sample_bytes_);
return node; return node;
} }

View File

@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> MnistNode::Copy() { std::shared_ptr<DatasetNode> MnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }

View File

@ -86,7 +86,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
RETURN_IF_NOT_OK(text_file_op->Init()); RETURN_IF_NOT_OK(text_file_op->Init());
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
// Inject ShuffleOp // Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr; std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0; int64_t num_rows = 0;

View File

@ -134,7 +134,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
RETURN_IF_NOT_OK(tf_reader_op->Init()); RETURN_IF_NOT_OK(tf_reader_op->Init());
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
// Inject ShuffleOp // Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr; std::shared_ptr<DatasetOp> shuffle_op = nullptr;

View File

@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
sampler_(sampler) {} sampler_(sampler) {}
std::shared_ptr<DatasetNode> VOCNode::Copy() { std::shared_ptr<DatasetNode> VOCNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
return node; return node;
} }

View File

@ -22,6 +22,7 @@
#include <algorithm> #include <algorithm>
#include "minddata/dataset/engine/datasetops/take_op.h" #include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
@ -68,5 +69,16 @@ Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for IRNodePass
Status TakeNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<TakeNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status TakeNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<TakeNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -64,6 +64,18 @@ class TakeNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
int32_t take_count_; int32_t take_count_;
}; };

View File

@ -104,14 +104,14 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status TransferNode::Accept(NodePass *p, bool *modified) { Status TransferNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<TransferNode>(), modified); return p->Visit(shared_from_base<TransferNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { Status TransferNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<TransferNode>(), modified); return p->VisitAfter(shared_from_base<TransferNode>(), modified);
} }

View File

@ -58,17 +58,17 @@ class TransferNode : public DatasetNode {
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override; Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor /// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit /// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified /// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override; Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::string queue_name_; std::string queue_name_;

View File

@ -79,14 +79,14 @@ Status ZipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
return Status::OK(); return Status::OK();
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status ZipNode::Accept(NodePass *p, bool *modified) { Status ZipNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->Visit(shared_from_base<ZipNode>(), modified); return p->Visit(shared_from_base<ZipNode>(), modified);
} }
// Visitor accepting method for NodePass // Visitor accepting method for IRNodePass
Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { Status ZipNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor // Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ZipNode>(), modified); return p->VisitAfter(shared_from_base<ZipNode>(), modified);
} }

View File

@ -64,19 +64,20 @@ class ZipNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private: private:
std::vector<std::shared_ptr<DatasetNode>> datasets_; std::vector<std::shared_ptr<DatasetNode>> datasets_;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset

View File

@ -6,9 +6,12 @@ add_library(engine-opt OBJECT
post/repeat_pass.cc post/repeat_pass.cc
pre/cache_error_pass.cc pre/cache_error_pass.cc
pre/cache_transform_pass.cc pre/cache_transform_pass.cc
pre/cache_validation_pass.cc
pre/epoch_ctrl_pass.cc
pre/epoch_injection_pass.cc pre/epoch_injection_pass.cc
pre/getter_pass.cc pre/getter_pass.cc
pre/input_validation_pass.cc pre/input_validation_pass.cc
pre/node_removal_pass.cc
pre/removal_pass.cc pre/removal_pass.cc
util/printer_pass.cc util/printer_pass.cc
) )

View File

@ -87,7 +87,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
// Driver method for TreePass // Driver method for TreePass
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
if (root_ir == nullptr || modified == nullptr) { if (root_ir == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
} }
@ -95,7 +95,7 @@ Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
} }
// Driver method for NodePass // Driver method for NodePass
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
if (root_ir == nullptr || modified == nullptr) { if (root_ir == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
} }
@ -110,7 +110,7 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
} }
// Helper function to perform DFS visit // Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
bool m = false; bool m = false;
RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
@ -125,7 +125,7 @@ Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
} }
// Helper function to perform BFS visit // Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
bool m = false; bool m = false;
// Initialize bfs queue with root // Initialize bfs queue with root
@ -151,121 +151,113 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
} }
// For non-leaf IR node // For non-leaf IR node
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif #endif
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif #endif
// For leaf IR Node
Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
////////////////////////////////// //////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Driver method for TreePass // Driver method for TreePass

View File

@ -113,26 +113,18 @@ class GeneratorOp;
// The base class Pass is the basic unit of tree transformation. // The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here. // The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> { class IRPass : public std::enable_shared_from_this<IRPass> {
public: public:
// Run the transformation pass against the IR tree. // Run the transformation pass against the IR tree.
// @param root_ir - Pointer to the IR tree to be transformed. // @param root_ir - Pointer to the IR tree to be transformed.
// @param modified - Pointer to the modified flag, // @param modified - Pointer to the modified flag,
virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0; virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0;
////////////////////////////////// virtual ~IRPass() = default;
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
//////////////////////////////////
virtual ~Pass() = default;
}; };
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. // IRTreePass is a basic Pass class which performs transformation on IR tree directly.
class TreePass : public Pass { class IRTreePass : public IRPass {
public: public:
/// \brief Run the transformation pass against the IR tree. /// \brief Run the transformation pass against the IR tree.
/// \param[inout] root_ir Pointer to the IR tree to be transformed. /// \param[inout] root_ir Pointer to the IR tree to be transformed.
@ -145,44 +137,29 @@ class TreePass : public Pass {
/// \param[inout] Indicate if the tree was modified. /// \param[inout] Indicate if the tree was modified.
/// \return Status The error code return /// \return Status The error code return
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
Status Run(ExecutionTree *tree, bool *modified) final;
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
//////////////////////////////////
}; };
// NodePass is a base Pass class which performs transformation on node visiting. // IRNodePass is a base Pass class which performs transformation on node visiting.
// NodePass implements Visitor design pattern. // IRNodePass implements Visitor design pattern.
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
// and the other when all the descending nodes are visited. // and the other when all the descending nodes are visited.
// Actual transformation is done by implementing a new derived class of NodePass. // Actual transformation is done by implementing a new derived class of IRNodePass.
// The derived class will implement the method Visit()/VisitAfter() passing specified node types // The derived class will implement the method Visit()/VisitAfter() passing specified node types
// it wants to action on them, overriding the ones defined in NodePass. // it wants to action on them, overriding the ones defined in IRNodePass.
// If the derived class wants to perform the same action on all node types, // If the derived class wants to perform the same action on all node types,
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
// to call the Visit()/VisitAfter() in this parent NodePass class. // to call the Visit()/VisitAfter() in this parent IRNodePass class.
class NodePass : public Pass { class IRNodePass : public IRPass {
public: public:
// Tree traversal order // Tree traversal order
enum Order { DFS, BFS }; enum Order { DFS, BFS };
// Constructor // Constructor
// Default DFS traversal // Default DFS traversal
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; }
~NodePass() = default; ~IRNodePass() = default;
/// \brief Run the transformation pass against the IR tree /// \brief Run the transformation pass against the IR tree
/// \param[inout] root_ir Pointer to the IR tree to be transformed /// \param[inout] root_ir Pointer to the IR tree to be transformed
@ -251,12 +228,70 @@ class NodePass : public Pass {
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif #endif
// Leaf IR node
virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified);
////////////////////////////////// private:
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
// Tree traversal order of the IRNodePass
Order traversalOrder_;
};
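The doc comment above describes the visitor contract of IRNodePass. As a concrete illustration (not part of this commit), the following minimal sketch shows a derived pass; the name NodeCounterPass and its count() accessor are hypothetical, and only the Visit()/Run() interfaces declared in this header are assumed.
// Hedged sketch: a derived IRNodePass that counts IR nodes (hypothetical names).
class NodeCounterPass : public IRNodePass {
 public:
  NodeCounterPass() : IRNodePass(Order::DFS), count_(0) {}
  ~NodeCounterPass() = default;

  // Fall-back visitor: the per-node-type overloads in IRNodePass funnel into
  // this DatasetNode overload unless a more specific one is overridden.
  Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override {
    *modified = false;
    ++count_;
    return Status::OK();
  }

  int32_t count() const { return count_; }

 private:
  int32_t count_;
};
// Usage sketch (inside a function returning Status):
//   bool modified = false;
//   NodeCounterPass counter;
//   RETURN_IF_NOT_OK(counter.Run(root_ir, &modified));  // default DFS walk of the IR tree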
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
public:
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
virtual ~Pass() = default;
};
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
public:
/// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
Status Run(ExecutionTree *tree, bool *modified) final;
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
// NodePass is a base Pass class which performs transformation on node visiting.
// NodePass implements Visitor design pattern.
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
// and the other when all the descending nodes are visited.
// Actual transformation is done by implementing a new derived class of NodePass.
// The derived class will implement the method Visit()/VisitAfter() passing specified node types
// it wants to action on them, overriding the ones defined in NodePass.
// If the derived class wants to perform the same action on all node types,
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
// to call the Visit()/VisitAfter() in this parent NodePass class.
class NodePass : public Pass {
public:
// Tree traversal order
enum Order { DFS, BFS };
// Constructor
// Default DFS traversal
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
~NodePass() = default;
/// \brief Run the transformation pass against the execution tree /// \brief Run the transformation pass against the execution tree
/// \param[inout] tree Pointer to the execution tree to be transformed /// \param[inout] tree Pointer to the execution tree to be transformed
/// \param[inout] modified Indicator if the tree was changed /// \param[inout] modified Indicator if the tree was changed
@ -326,27 +361,18 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif #endif
//////////////////////////////////
private: private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Helper function to perform DFS visit // Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
// Helper function to perform BFS visit // Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified); Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
//////////////////////////////////
// Tree traversal order of the NodePass // Tree traversal order of the NodePass
Order traversalOrder_; Order traversalOrder_;
}; };
//////////////////////////////////
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,163 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/include/transforms.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {}
// Returns an error if BatchNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<BatchNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if ConcatNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ConcatNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if FilterNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<FilterNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if SkipNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<SkipNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if TakeNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<TakeNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if ZipNode exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ZipNode>): visiting " << node->Name() << ".";
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache.");
}
if (node->IsCached()) {
RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached.");
}
return Status::OK();
}
// Returns an error if MapNode with non-deterministic tensor operations exists under a cache
Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<MapNode>): visiting " << node->Name() << ".";
if (node->IsCached()) {
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported.");
}
// If Map is created to be cached, set the flag indicating we found an operation with a cache.
is_cached_ = true;
auto tfuncs = node->TensorOperations();
for (size_t i = 0; i < tfuncs.size(); i++) {
if (tfuncs[i]->IsRandomOp()) {
RETURN_STATUS_UNEXPECTED(
"MapNode with non-deterministic operations is not supported as a descendant of cache.");
}
}
}
return Status::OK();
}
// Flag an error if we have a cache over another cache
Status CacheValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<DatasetNode>): visiting " << node->Name() << ".";
if (node->IsCached()) {
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported.");
}
// If this node is created to be cached, set the flag.
is_cached_ = true;
}
if (node->IsLeaf() && node->IsMappable()) {
is_mappable_ = true;
}
return Status::OK();
}
// Returns an error if MappableSource <- Repeat <- Node with a cache
// Because there is no operator in the cache hit stream to consume EoEs, caching above a repeat causes problems.
Status CacheValidationPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<RepeatNode>): visiting " << node->Name() << ".";
if (is_cached_ && is_mappable_) {
RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported.");
}
return Status::OK();
}
Status CacheValidationPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<DatasetNode>): visiting " << node->Name() << ".";
// Reset the flag when all descendants are visited
if (node->IsCached()) {
is_cached_ = false;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
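As a hedged sketch of how a validation pass such as this one could be driven ahead of execution (the helper name RunCacheValidation is hypothetical; the actual wiring lives in TreeAdapter::PrePass and may assemble its pass list differently):
// Hedged sketch, not the actual TreeAdapter code. Assumes <memory>, <vector>
// and the pass headers above are included, inside namespace mindspore::dataset.
Status RunCacheValidation(std::shared_ptr<DatasetNode> root_ir) {
  std::vector<std::unique_ptr<IRPass>> actions;
  actions.emplace_back(std::make_unique<CacheValidationPass>());
  bool modified = false;
  for (auto &pass : actions) {
    // Fails fast with an unexpected-error status on an invalid cache layout.
    RETURN_IF_NOT_OK(pass->Run(root_ir, &modified));
  }
  return Status::OK();
}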

View File

@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_
#include <memory>
#include <stack>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class CacheValidationPass cache_validation_pass.h
/// \brief This is an IRNodePass whose job is to catch invalid tree configurations related to cache and generate failures.
class CacheValidationPass : public IRNodePass {
public:
/// \brief Constructor
CacheValidationPass();
/// \brief Destructor
~CacheValidationPass() = default;
/// \brief Returns an error if BatchNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
/// \brief Returns an error if ConcatNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override;
/// \brief Returns an error if FilterNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override;
/// \brief Returns an error if SkipNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
/// \brief Returns an error if TakeNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
/// \brief Returns an error if ZipNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override;
/// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
/// \brief Returns an error if there is a cache over another cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Returns an error if a cache sits over a RepeatNode of a mappable dataset
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override;
/// \brief Resets the cache tracking flag once the subtree below a cached node has been fully visited
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
private:
bool is_cached_;
bool is_mappable_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
namespace mindspore {
namespace dataset {
// constructor
EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetNode> node)
: injection_point_(nullptr), num_epochs_(-1) {}
// Records the injection point and the number of epochs from the root node
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *modified) {
// The injection is at the child of the root node
injection_point_ = node;
num_epochs_ = node->num_epochs();
return Status::OK();
}
// Performs finder work for BuildVocabNode that has special rules about epoch control injection
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
#endif
Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
// Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
// Record this node as the injection point; the epoch control op will later be inserted below it.
injection_point_ = node;
return Status::OK();
}
// constructor
EpochCtrlPass::EpochCtrlPass() {}
// Runs an injection pass to inject operators needed at the pre pass stage
Status EpochCtrlPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
MS_LOG(INFO) << "Pre pass: Injection pass started.";
// First, run the finder to gather the injection information before driving the op injection work.
// The finder records the injection point and the number of epochs.
EpochCtrlPass::InjectionFinder finder(root_ir);
RETURN_IF_NOT_OK(finder.Run(root_ir, modified));
// The first injection logic is to check if we should inject the epoch control op below the injection point.
// Do not inject the op if the number of epochs is 1.
std::shared_ptr<DatasetNode> parent = finder.injection_point();
int32_t num_epochs = finder.num_epochs();
if (num_epochs != 1 && parent != nullptr) {
CHECK_FAIL_RETURN_UNEXPECTED(parent->Children().size() == 1, "EpochCtrl must be injected on only one child.");
auto epoch_ctrl_node = std::make_shared<EpochCtrlNode>(num_epochs);
RETURN_IF_NOT_OK(parent->InsertBelow(epoch_ctrl_node));
}
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
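A hedged sketch of driving this pass and of the expected tree change, assuming InsertBelow() places the new node between the injection point and its single child (consistent with the Children().size() == 1 check above); the helper name InjectEpochCtrl is hypothetical.
// Hedged sketch (hypothetical helper). For num_epochs > 1 the intent is roughly:
//   before:  RootNode -> MapNode -> ...                      (injection point: RootNode)
//   after:   RootNode -> EpochCtrlNode(num_epochs) -> MapNode -> ...
Status InjectEpochCtrl(std::shared_ptr<DatasetNode> root_ir) {
  bool modified = false;
  EpochCtrlPass epoch_ctrl_pass;
  // IRTreePass::Run() validates the pointers and dispatches to RunOnTree().
  return epoch_ctrl_pass.Run(root_ir, &modified);
}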

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_CTRL_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_CTRL_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class EpochCtrlPass epoch_ctrl_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
/// parsing.
class EpochCtrlPass : public IRTreePass {
/// \class InjectionFinder
/// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
/// it may need to inject.
class InjectionFinder : public IRNodePass {
public:
/// \brief Constructor
explicit InjectionFinder(std::shared_ptr<DatasetNode> node);
/// \brief Destructor
~InjectionFinder() = default;
/// \brief Records the injection point and the number of epochs from the root node.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<RootNode> node, bool *modified) override;
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override;
#ifndef ENABLE_ANDROID
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override;
#endif
/// \brief Register the TransferNode for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override;
/// \brief Getter
std::shared_ptr<DatasetNode> injection_point() { return injection_point_; }
/// \brief Getter
int32_t num_epochs() { return num_epochs_; }
private:
std::shared_ptr<DatasetNode> injection_point_;
int32_t num_epochs_;
};
public:
/// \brief Constructor
EpochCtrlPass();
/// \brief Destructor
~EpochCtrlPass() = default;
/// \brief Runs an injection pass to inject operators needed at the pre pass stage
/// \param[inout] root_ir The IR tree to operate on.
/// \param[inout] modified Indicates whether the tree was modified.
/// \return Status The error code return
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif  // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_CTRL_PASS_H_

View File

@ -26,7 +26,7 @@ namespace dataset {
/// \class InputValidationPass /// \class InputValidationPass
/// \brief This is a parse pass that validates input parameters of the IR tree. /// \brief This is a parse pass that validates input parameters of the IR tree.
class InputValidationPass : public NodePass { class InputValidationPass : public IRNodePass {
/// \brief Runs a validation pass to check input parameters /// \brief Runs a validation pass to check input parameters
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited /// \param[inout] *modified indicates whether the node has been visited

View File

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
namespace mindspore {
namespace dataset {
NodeRemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
// Identifies the subtree below this node as a cached descendant tree.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Node removal pass: Operation with cache found, identified descendant tree.";
if (node->IsCached()) {
is_caching_ = true;
}
return Status::OK();
}
// Resets the tracking of the cache within the tree
Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: Descendant walk is complete.";
if (is_caching_ && node->IsLeaf()) {
// Mark this leaf node to indicate it is a descendant of an operator with cache.
// This is currently used in non-mappable dataset (leaf) nodes to not add a ShuffleOp in DatasetNode::Build().
node->HasCacheAbove();
}
is_caching_ = false;
return Status::OK();
}
// Perform ShuffleOp removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
*modified = false;
#if 0
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(INFO) << "Shuffle under an operation with cache is identified for removal.";
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
#endif
return Status::OK();
}
// constructor
NodeRemovalPass::NodeRemovalPass() {}
// Walk the tree to collect the nodes to remove, then remove them.
Status NodeRemovalPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
MS_LOG(INFO) << "Pre pass: node removal pass started.";
// Create the removal node pass which can identify which nodes need to be removed.
std::unique_ptr<NodeRemovalPass::RemovalNodes> removal_nodes = std::make_unique<NodeRemovalPass::RemovalNodes>();
RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified));
// Then, execute the removal of any nodes that were set up for removal
for (auto node : removal_nodes->nodes_to_remove()) {
RETURN_IF_NOT_OK(node->Remove());
}
MS_LOG(INFO) << "Pre pass: node removal pass complete.";
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
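To illustrate the collect-then-remove pattern used above, here is a hedged sketch with a hypothetical collector (ShuffleCollector and RemoveAllShuffles are illustrative names); it relies only on the IRNodePass Visit()/Run() overloads and DatasetNode::Remove() seen in this commit.
// Hedged sketch of the two-phase pattern (hypothetical names).
class ShuffleCollector : public IRNodePass {
 public:
  // Phase 1: record every ShuffleNode encountered during the walk.
  Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override {
    *modified = false;
    to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
    return Status::OK();
  }

  const std::vector<std::shared_ptr<DatasetNode>> &to_remove() const { return to_remove_; }

 private:
  std::vector<std::shared_ptr<DatasetNode>> to_remove_;
};
Status RemoveAllShuffles(std::shared_ptr<DatasetNode> root_ir) {
  bool modified = false;
  ShuffleCollector collector;
  RETURN_IF_NOT_OK(collector.Run(root_ir, &modified));  // phase 1: collect
  for (auto &node : collector.to_remove()) {            // phase 2: detach from the IR tree
    RETURN_IF_NOT_OK(node->Remove());
  }
  return Status::OK();
}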

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class NodeRemovalPass node_removal_pass.h
/// \brief This is an IR tree pass that will remove nodes. It uses its nested RemovalNodes visitor to first
///     identify which nodes should be removed, and then removes them.
class NodeRemovalPass : public IRTreePass {
  /// \class RemovalNodes
  /// \brief This is an IRNodePass whose job is to identify which nodes should be removed.
  ///     It works in conjunction with NodeRemovalPass.
  class RemovalNodes : public IRNodePass {
   public:
    /// \brief Constructor
    RemovalNodes();

    /// \brief Destructor
    ~RemovalNodes() = default;

    /// \brief Identifies the subtree below this node as a cached descendant tree.
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;

    /// \brief Resets the tracking of the cache within the tree
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;

    /// \brief Perform ShuffleNode removal check
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override;

    /// \brief Getter
    /// \return All the nodes to be removed
    std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }

   private:
    bool is_caching_;
    std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_;
  };

 public:
  /// \brief Constructor
  NodeRemovalPass();

  /// \brief Destructor
  ~NodeRemovalPass() = default;

  /// \brief Runs a RemovalNodes pass first to find out which nodes to remove, then removes them.
  /// \param[inout] root_ir The root of the IR tree to operate on.
  /// \param[inout] modified Indicator of whether the tree was modified.
  /// \return Status The error code return
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_

View File

@@ -17,34 +17,25 @@
 #include "minddata/dataset/engine/tree_adapter.h"
 #include "minddata/dataset/core/client.h"
-#include "minddata/dataset/include/datasets.h"
 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
 #include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
+#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
 #include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
+#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
 namespace mindspore {
 namespace dataset {
 Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
-  // Vector of actions in validation pass
-  std::vector<std::unique_ptr<NodePass>> validations;
+  // Vector of actions in pre-pass phase
+  std::vector<std::unique_ptr<IRPass>> actions;
   MS_LOG(INFO) << "Running pre pass loops.";
-  validations.push_back(std::make_unique<InputValidationPass>());
-  // Vector of flags for each action
-  // Apply validation actions
-  for (auto i = 0; i < validations.size(); i++) {
-    auto modified = false;
-    // InputValidationPass does not change the IR tree. We don't need to capture the "modified" value.
-    RETURN_IF_NOT_OK(validations[i]->Run(ir, &modified));
-  }
-  // Vector of actions in pre-pass phase
-  std::vector<std::unique_ptr<Pass>> actions;
-  // We will gradually move CacheErrorPass, EpochInjectionPass, CacheTransformPass
-  // from ExecutionTree::PrepareTreePreAction to here.
+  actions.push_back(std::make_unique<InputValidationPass>());
+  actions.push_back(std::make_unique<CacheValidationPass>());
+  actions.push_back(std::make_unique<NodeRemovalPass>());
+  actions.push_back(std::make_unique<EpochCtrlPass>());
   // Vector of flags for each action
   std::vector<bool> modified(actions.size(), false);

@@ -60,7 +51,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
 Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
   // Vector of optimizations
-  std::vector<std::unique_ptr<NodePass>> optimizations;
+  std::vector<std::unique_ptr<IRNodePass>> optimizations;
   MS_LOG(INFO) << "Running optimization pass loops";
   // We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here.

@@ -79,7 +70,7 @@ Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
 Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
   // Vector of actions in post-pass phase
-  std::vector<std::unique_ptr<Pass>> actions;
+  std::vector<std::unique_ptr<IRPass>> actions;
   MS_LOG(INFO) << "Running post pass loops.";
   // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.

@@ -96,10 +87,6 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
 }
 Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
-  // Check if pipeline is valid or not
-  CHECK_FAIL_RETURN_UNEXPECTED(ir->Parent().size() <= 1,
-                               "The data pipeline is not a tree (i.e. one node has two consumers)");
   // Build the DatasetOp ExecutionTree from the optimized IR tree
   std::vector<std::shared_ptr<DatasetOp>> ops;
   RETURN_IF_NOT_OK(ir->Build(&ops));

@@ -130,8 +117,12 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
   RETURN_UNEXPECTED_IF_NULL(input_ir);
   MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
+  // We will first walk the input tree to sanity check this is not a graph
+  // Flag an error when it is not a tree
+  CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)");
   // Copy the input IR tree and insert under the root node
-  // Create a root node to host the input IR tree, the deepcopied tree will be passed to optimization pass
+  // Create a root node to host the new copy of the input IR tree to pass to the optimizer
   auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
   MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';

@@ -151,11 +142,9 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
   // This will evolve in the long run
   tree_ = std::make_unique<ExecutionTree>();
-  // Build the Execution tree from the child of the root node
+  // Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree
   std::shared_ptr<DatasetOp> root_op;
-  // input_ir is the ir node before the deepcopy.
-  // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in
-  RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op));
+  RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op));
   RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
   if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);

@@ -163,7 +152,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
   // Note: We will gradually move the pre pass, optimizer pass, and post pass
   // on ExecutionTree to perform on IR tree.
   // Prepare the tree
-  RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
+  RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true));
   // After the tree is prepared, the col_name_id_map can safely be obtained
   column_name_map_ = tree_->root()->column_name_id_map();
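The hunk above registers the four migrated passes, but the loop that actually runs them and records the per-pass modified flags falls outside the diff context. The sketch below shows that driver pattern in isolation; Pass and Tree here are hypothetical stand-ins, not the real IRPass or DatasetNode interfaces, so this illustrates the control flow rather than the actual TreeAdapter code.

// Standalone sketch of a pass-pipeline driver: build the pass list, run each pass
// in order, and record a per-pass "modified" flag for logging.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Tree {
  std::string plan = "Source -> Map -> Batch";
};

struct Pass {
  std::string name;
  // Returns true when the pass changed the tree.
  std::function<bool(Tree *)> run;
};

int main() {
  Tree tree;
  std::vector<Pass> actions;
  actions.push_back({"InputValidation", [](Tree *) { return false; }});
  actions.push_back({"NodeRemoval", [](Tree *t) {
                       t->plan = "Source -> Batch";  // pretend a node was removed
                       return true;
                     }});
  actions.push_back({"EpochCtrl", [](Tree *) { return false; }});

  // One flag per pass, mirroring the "vector of flags for each action" idea.
  std::vector<bool> modified(actions.size(), false);
  for (size_t i = 0; i < actions.size(); ++i) {
    modified[i] = actions[i].run(&tree);
    std::cout << actions[i].name << (modified[i] ? " modified the tree" : " made no change") << std::endl;
  }
  std::cout << "Final plan: " << tree.plan << std::endl;
  return 0;
}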

View File

@@ -44,7 +44,7 @@ Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
 }
 void NormalizeOp::Print(std::ostream &out) const {
-  out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl;
+  out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl;
 }
 }  // namespace dataset
 }  // namespace mindspore
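The one-line fix above makes Print stream the values the smart pointers refer to rather than the pointers themselves. The tiny standalone example below shows the difference, using a std::shared_ptr<std::string> as a stand-in for the mean_/std_ tensor members (an assumption made purely for illustration).

// Streaming a std::shared_ptr prints its stored address; dereferencing prints the value.
#include <iostream>
#include <memory>
#include <string>

int main() {
  auto mean = std::make_shared<std::string>("[0.5, 0.5, 0.5]");
  std::cout << "pointer: " << mean << std::endl;           // prints an address such as 0x55..., not useful in a log
  std::cout << "value:   " << *(mean.get()) << std::endl;  // prints [0.5, 0.5, 0.5], matching the fix above
  return 0;
}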

View File

@@ -83,7 +83,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) {
   };
   exe_tree->SetPrePassOverride(pass);
-  ASSERT_OK(exe_tree->PrepareTreePreAction());
+  ASSERT_OK(exe_tree->PreAction());
   std::stringstream ss;
   // print the tree in std::string as a way to verify that nodes are indeed removed

@@ -124,7 +124,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
   };
   exe_tree->SetPrePassOverride(pass);
-  ASSERT_OK(exe_tree->PrepareTreePreAction());
+  ASSERT_OK(exe_tree->PreAction());
   std::stringstream ss;
   // print the tree in std::string as a way to verify that nodes are indeed removed
   exe_tree->Print(ss);

View File

@@ -237,7 +237,7 @@ def test_cache_map_failure1():
         num_iter = 0
         for _ in ds1.create_dict_iterator(num_epochs=1):
             num_iter += 1
-    assert "Nested cache operations is not supported!" in str(e.value)
+    assert "Nested cache operations" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure1 Ended.\n')

@@ -279,7 +279,7 @@ def test_cache_map_failure2():
         num_iter = 0
         for _ in dsz.create_dict_iterator():
             num_iter += 1
-    assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure2 Ended.\n')

@@ -319,7 +319,7 @@ def test_cache_map_failure3():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
             num_iter += 1
-    assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure3 Ended.\n')

@@ -361,7 +361,7 @@ def test_cache_map_failure4():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
             num_iter += 1
-    assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure4 Ended.\n')

@@ -402,7 +402,7 @@ def test_cache_map_failure5():
         num_iter = 0
         for _ in data.create_dict_iterator():
             num_iter += 1
-    assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value)
+    assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure5 Ended.\n')

@@ -522,7 +522,7 @@ def test_cache_map_failure8():
         num_iter = 0
         for _ in ds1.create_dict_iterator(num_epochs=1):
             num_iter += 1
-    assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value)
+    assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure8 Ended.\n')

@@ -564,7 +564,7 @@ def test_cache_map_failure9():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
             num_iter += 1
-    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure9 Ended.\n')

@@ -606,7 +606,7 @@ def test_cache_map_failure10():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
             num_iter += 1
-    assert "SkipOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure10 Ended.\n')

@@ -655,13 +655,13 @@ def test_cache_map_split1():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
             num_iter += 1
-    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
     with pytest.raises(RuntimeError) as e:
         num_iter = 0
         for _ in ds2.create_dict_iterator():
             num_iter += 1
-    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
     logger.info('test_cache_split1 Ended.\n')