forked from mindspore-Ecosystem/mindspore
Migrate 3 pre passes to IR optimizer, namely, cache_error_pass, epoch_injection, and removal_pass
This commit is contained in:
parent
73c91e05b1
commit
d69a29a44e
|
@ -574,7 +574,7 @@ Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *
|
||||||
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
|
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
|
||||||
return pre;
|
return pre;
|
||||||
});
|
});
|
||||||
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
|
RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
|
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
|
||||||
int64_t row_cnt = 0;
|
int64_t row_cnt = 0;
|
||||||
|
|
|
@ -214,7 +214,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
||||||
// The driver of the prepare phase of the execution tree.
|
// The driver of the prepare phase of the execution tree.
|
||||||
// Prepare phase consists of three sub phases
|
// Prepare phase consists of three sub phases
|
||||||
//
|
//
|
||||||
// 1. PrepareTreePreAction()
|
// 1. PreAction()
|
||||||
// Compulsory transformation/action pre optimization.
|
// Compulsory transformation/action pre optimization.
|
||||||
// For example, CacheOp Insertion
|
// For example, CacheOp Insertion
|
||||||
//
|
//
|
||||||
|
@ -222,41 +222,44 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
||||||
// Optimization transformation/action, optional
|
// Optimization transformation/action, optional
|
||||||
// For example, MapOp Fusion
|
// For example, MapOp Fusion
|
||||||
//
|
//
|
||||||
// 3. PrepareTreePostAction()
|
// 3. PostAction()
|
||||||
// Compulsory transformation/action post optimization.
|
// Compulsory transformation/action post optimization.
|
||||||
// For example, repeatOp inlining
|
// For example, repeatOp inlining
|
||||||
//
|
//
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status ExecutionTree::Prepare(int32_t num_epochs) {
|
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
|
||||||
num_epochs_ = num_epochs;
|
num_epochs_ = num_epochs;
|
||||||
|
partially_prepare_ = partial;
|
||||||
|
|
||||||
// Pre optimization compulsory transformation
|
// Pre optimization compulsory transformation
|
||||||
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
|
RETURN_IF_NOT_OK(this->PreAction());
|
||||||
|
|
||||||
// If optional optimizations are enabled
|
// If optional optimizations are enabled
|
||||||
if (optimize_) {
|
if (optimize_) {
|
||||||
RETURN_IF_NOT_OK(this->Optimize());
|
RETURN_IF_NOT_OK(this->Optimize());
|
||||||
}
|
}
|
||||||
// Post optimization compulsory transformation
|
// Post optimization compulsory transformation
|
||||||
RETURN_IF_NOT_OK(this->PrepareTreePostAction());
|
RETURN_IF_NOT_OK(this->PostAction());
|
||||||
|
|
||||||
|
// The tree is ready to be prepared.
|
||||||
|
tree_state_ = kDeTStatePrepare;
|
||||||
|
|
||||||
// Existing transformation implementation, will be removed later
|
// Existing transformation implementation, will be removed later
|
||||||
RETURN_IF_NOT_OK(this->PrepareDeprecated());
|
RETURN_IF_NOT_OK(this->PrepareDeprecated());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ExecutionTree::PrepareTreePreAction() {
|
Status ExecutionTree::PreAction() {
|
||||||
bool modified = false;
|
bool modified = false;
|
||||||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||||
// Construct pre actions
|
// Construct pre actions
|
||||||
|
if (!partially_prepare_) {
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
pre_actions.push_back(std::make_unique<CacheErrorPass>());
|
pre_actions.push_back(std::make_unique<CacheErrorPass>());
|
||||||
#endif
|
#endif
|
||||||
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
|
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
|
||||||
pre_actions.push_back(std::make_unique<RemovalPass>());
|
pre_actions.push_back(std::make_unique<RemovalPass>());
|
||||||
#ifndef ENABLE_ANDROID
|
}
|
||||||
pre_actions.push_back(std::make_unique<CacheTransformPass>());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// this offers a way to override the preset optimization pass with customized ones
|
// this offers a way to override the preset optimization pass with customized ones
|
||||||
// this is used when certain nodes are removed for tree getters
|
// this is used when certain nodes are removed for tree getters
|
||||||
|
@ -276,15 +279,17 @@ Status ExecutionTree::PrepareTreePreAction() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ExecutionTree::PrepareTreePostAction() {
|
Status ExecutionTree::PostAction() {
|
||||||
// The tree is ready to be prepared.
|
|
||||||
tree_state_ = kDeTStatePrepare;
|
|
||||||
|
|
||||||
bool modified = false;
|
bool modified = false;
|
||||||
OptPass post_actions;
|
OptPass post_actions;
|
||||||
// Construct pre actions
|
// Construct pre actions
|
||||||
MS_LOG(INFO) << "Running post pass loops.";
|
MS_LOG(INFO) << "Running post pass loops.";
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
|
// Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
|
||||||
|
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
|
||||||
|
// This is because Python API binding to TensorOperation is still in progress.
|
||||||
|
post_actions.push_back(std::make_unique<CacheErrorPass>());
|
||||||
|
post_actions.push_back(std::make_unique<CacheTransformPass>());
|
||||||
post_actions.push_back(std::make_unique<RepeatPass>());
|
post_actions.push_back(std::make_unique<RepeatPass>());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -340,9 +345,6 @@ Status ExecutionTree::PrepareDeprecated() {
|
||||||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||||
// node actions during a tree walk.
|
// node actions during a tree walk.
|
||||||
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
|
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
|
||||||
// execute PreAction
|
|
||||||
RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
|
|
||||||
|
|
||||||
// Before going down into children, make any prepare flags updates based on this operator.
|
// Before going down into children, make any prepare flags updates based on this operator.
|
||||||
uint32_t op_prep_flags = dataset_op->PrepareFlags();
|
uint32_t op_prep_flags = dataset_op->PrepareFlags();
|
||||||
BitSet(&prepare_flags_, op_prep_flags);
|
BitSet(&prepare_flags_, op_prep_flags);
|
||||||
|
|
|
@ -169,7 +169,7 @@ class ExecutionTree {
|
||||||
// The driver of the prepare phase of the execution tree.
|
// The driver of the prepare phase of the execution tree.
|
||||||
// Prepare phase consists of three sub phases
|
// Prepare phase consists of three sub phases
|
||||||
//
|
//
|
||||||
// 1. PrepareTreePreAction()
|
// 1. PreAction()
|
||||||
// Compulsory transformation/action pre optimization.
|
// Compulsory transformation/action pre optimization.
|
||||||
// For example, CacheOp Insertion
|
// For example, CacheOp Insertion
|
||||||
//
|
//
|
||||||
|
@ -177,20 +177,20 @@ class ExecutionTree {
|
||||||
// Optimization transformation/action, optional
|
// Optimization transformation/action, optional
|
||||||
// For example, MapOp Fusion
|
// For example, MapOp Fusion
|
||||||
//
|
//
|
||||||
// 3. PrepareTreePostAction()
|
// 3. PostAction()
|
||||||
// Compulsory transformation/action post optimization.
|
// Compulsory transformation/action post optimization.
|
||||||
// For example, repeatOp inlining
|
// For example, repeatOp inlining
|
||||||
//
|
//
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status Prepare(int num_epochs = -1);
|
Status Prepare(int num_epochs = -1, bool partial = false);
|
||||||
|
|
||||||
// Compulsory transformation/action pre optimization.
|
// Compulsory transformation/action pre optimization.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status PrepareTreePreAction();
|
Status PreAction();
|
||||||
|
|
||||||
// Compulsory transformation/action post optimization.
|
// Compulsory transformation/action post optimization.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status PrepareTreePostAction();
|
Status PostAction();
|
||||||
|
|
||||||
// Optimization transformation/action, optional.
|
// Optimization transformation/action, optional.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
|
@ -281,6 +281,7 @@ class ExecutionTree {
|
||||||
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
||||||
bool optimize_; // Flag to enable optional optimizations
|
bool optimize_; // Flag to enable optional optimizations
|
||||||
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
|
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
|
||||||
|
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -139,5 +140,16 @@ Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status BatchNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->Visit(shared_from_base<BatchNode>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status BatchNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->VisitAfter(shared_from_base<BatchNode>(), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -74,6 +74,18 @@ class BatchNode : public DatasetNode {
|
||||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
int64_t *dataset_size) override;
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t batch_size_;
|
int32_t batch_size_;
|
||||||
bool drop_remainder_;
|
bool drop_remainder_;
|
||||||
|
|
|
@ -46,12 +46,40 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
|
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
|
||||||
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||||
element_length_function_, pad_info_, pad_to_bucket_boundary_);
|
element_length_function_, pad_info_, pad_to_bucket_boundary_,
|
||||||
|
drop_remainder_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BucketBatchByLengthNode::Print(std::ostream &out) const {
|
void BucketBatchByLengthNode::Print(std::ostream &out) const {
|
||||||
out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)";
|
out << Name() + "(columns:" + PrintColumns(column_names_);
|
||||||
|
int i = 0;
|
||||||
|
for (auto it : bucket_boundaries_) {
|
||||||
|
if (i == 0) {
|
||||||
|
out << ",bucket_boundaries:{";
|
||||||
|
}
|
||||||
|
out << it;
|
||||||
|
if (i < bucket_boundaries_.size() - 1) {
|
||||||
|
out << ",";
|
||||||
|
} else {
|
||||||
|
out << "}";
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
i = 0;
|
||||||
|
for (auto it : bucket_batch_sizes_) {
|
||||||
|
if (i == 0) {
|
||||||
|
out << ",bucket_batch_sizes:{";
|
||||||
|
}
|
||||||
|
out << it;
|
||||||
|
if (i < bucket_batch_sizes_.size() - 1) {
|
||||||
|
out << ",";
|
||||||
|
} else {
|
||||||
|
out << "}";
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
out << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
|
|
|
@ -90,14 +90,14 @@ Status BuildSentenceVocabNode::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
|
Status BuildSentenceVocabNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
|
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status BuildSentenceVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
|
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,17 +59,17 @@ class BuildSentenceVocabNode : public DatasetNode {
|
||||||
/// \return Status Status::OK() if all the parameters are valid
|
/// \return Status Status::OK() if all the parameters are valid
|
||||||
Status ValidateParams() override;
|
Status ValidateParams() override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<SentencePieceVocab> vocab_;
|
std::shared_ptr<SentencePieceVocab> vocab_;
|
||||||
|
|
|
@ -85,14 +85,14 @@ Status BuildVocabNode::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
|
Status BuildVocabNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<BuildVocabNode>(), modified);
|
return p->Visit(shared_from_base<BuildVocabNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status BuildVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
|
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,17 +58,17 @@ class BuildVocabNode : public DatasetNode {
|
||||||
/// \return Status Status::OK() if all the parameters are valid
|
/// \return Status Status::OK() if all the parameters are valid
|
||||||
Status ValidateParams() override;
|
Status ValidateParams() override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Vocab> vocab_;
|
std::shared_ptr<Vocab> vocab_;
|
||||||
|
|
|
@ -39,8 +39,10 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
|
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
|
||||||
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
// create an empty vector to copy a concat
|
// create an empty vector to copy a concat
|
||||||
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
|
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
|
||||||
|
children_flag_and_nums_, children_start_end_index_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,14 +82,14 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status ConcatNode::Accept(NodePass *p, bool *modified) {
|
Status ConcatNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<ConcatNode>(), modified);
|
return p->Visit(shared_from_base<ConcatNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status ConcatNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
|
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,17 +66,17 @@ class ConcatNode : public DatasetNode {
|
||||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -242,9 +242,27 @@ DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) {
|
||||||
worker_connector_size_ = cfg->worker_connector_size();
|
worker_connector_size_ = cfg->worker_connector_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const bool DatasetNode::IsTree() const {
|
||||||
|
bool is_tree = true;
|
||||||
|
if (this->parent_.size() > 1) {
|
||||||
|
MS_LOG(WARNING) << Name() << " has more than one parent.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (const auto &child : children_) {
|
||||||
|
is_tree = child->IsTree();
|
||||||
|
if (!is_tree) {
|
||||||
|
MS_LOG(WARNING) << Name() << " has more than one parent.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return is_tree;
|
||||||
|
}
|
||||||
|
|
||||||
// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
|
// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
|
||||||
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
|
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
|
||||||
std::shared_ptr<DatasetNode> new_node = this->Copy();
|
std::shared_ptr<DatasetNode> new_node = this->Copy();
|
||||||
|
// temporary fix to set the num_workers to the new node.
|
||||||
|
new_node->SetNumWorkers(this->num_workers_);
|
||||||
for (const auto &child : children_) {
|
for (const auto &child : children_) {
|
||||||
new_node->AddChild(child->DeepCopy());
|
new_node->AddChild(child->DeepCopy());
|
||||||
}
|
}
|
||||||
|
@ -298,12 +316,31 @@ void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
|
||||||
children_.push_back(child);
|
children_.push_back(child);
|
||||||
child->parent_.push_back(this);
|
child->parent_.push_back(this);
|
||||||
} else if (child != nullptr) {
|
} else if (child != nullptr) {
|
||||||
MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr.";
|
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
|
||||||
children_.push_back(child);
|
children_.push_back(child);
|
||||||
child->parent_.push_back(this);
|
child->parent_.push_back(this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert a node as a child of this node. This node's children becomes the children of the inserted node.
|
||||||
|
Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent.");
|
||||||
|
|
||||||
|
for (auto child : children_) {
|
||||||
|
node->children_.push_back(child);
|
||||||
|
child->parent_.clear();
|
||||||
|
child->parent_.push_back(node.get());
|
||||||
|
}
|
||||||
|
// Then establish the new parent-child relationship with the new parent.
|
||||||
|
children_.clear();
|
||||||
|
children_.push_back(node);
|
||||||
|
node->parent_.clear();
|
||||||
|
node->parent_.push_back(this);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Remove this node from its parent. Add the child of this node to its parent.
|
// Remove this node from its parent. Add the child of this node to its parent.
|
||||||
// for now, this remove is limited to node with a single child or no child
|
// for now, this remove is limited to node with a single child or no child
|
||||||
Status DatasetNode::Remove() {
|
Status DatasetNode::Remove() {
|
||||||
|
@ -325,14 +362,14 @@ Status DatasetNode::Remove() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
|
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
|
||||||
Status DatasetNode::Accept(NodePass *p, bool *modified) {
|
Status DatasetNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// This method will only be called if its derived class does not implement one.
|
// This method will only be called if its derived class does not implement one.
|
||||||
return p->Visit(shared_from_this(), modified);
|
return p->Visit(shared_from_this(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
|
// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
|
||||||
// after all child nodes are visited.
|
// after all child nodes are visited.
|
||||||
Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status DatasetNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// This method will only be called if its derived class does not implement one.
|
// This method will only be called if its derived class does not implement one.
|
||||||
return p->VisitAfter(shared_from_this(), modified);
|
return p->VisitAfter(shared_from_this(), modified);
|
||||||
}
|
}
|
||||||
|
@ -369,17 +406,5 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
|
||||||
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
|
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
|
||||||
Status SourceNode::Accept(NodePass *p, bool *modified) {
|
|
||||||
// Downcast shared pointer then call visitor
|
|
||||||
return p->Visit(shared_from_base<SourceNode>(), modified);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
|
||||||
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
|
|
||||||
// Downcast shared pointer then call visitor
|
|
||||||
return p->VisitAfter(shared_from_base<SourceNode>(), modified);
|
|
||||||
}
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace dataset {
|
||||||
|
|
||||||
class Dataset;
|
class Dataset;
|
||||||
class SamplerObj;
|
class SamplerObj;
|
||||||
class NodePass;
|
class IRNodePass;
|
||||||
class DatasetSizeGetter;
|
class DatasetSizeGetter;
|
||||||
|
|
||||||
// Names for non-leaf IR node
|
// Names for non-leaf IR node
|
||||||
|
@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||||
/// \brief Establish the parent-child relationship between this node and its child.
|
/// \brief Establish the parent-child relationship between this node and its child.
|
||||||
void AddChild(std::shared_ptr<DatasetNode> child);
|
void AddChild(std::shared_ptr<DatasetNode> child);
|
||||||
|
|
||||||
|
/// \brief Insert the input node below this node. This node's children becomes the children of the inserted node.
|
||||||
|
Status InsertBelow(std::shared_ptr<DatasetNode> node);
|
||||||
|
|
||||||
/// \brief detach this node from its parent, add its child (if any) to its parent
|
/// \brief detach this node from its parent, add its child (if any) to its parent
|
||||||
/// \return error code, return error if node has more than 1 children
|
/// \return error code, return error if node has more than 1 children
|
||||||
Status Remove();
|
Status Remove();
|
||||||
|
@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||||
/// \return True if the data of this node will be cached
|
/// \return True if the data of this node will be cached
|
||||||
const bool IsCached() const { return (cache_ != nullptr); }
|
const bool IsCached() const { return (cache_ != nullptr); }
|
||||||
|
|
||||||
|
/// \brief Check if this node is a tree
|
||||||
|
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
|
||||||
|
const bool IsTree() const;
|
||||||
|
|
||||||
|
/// \brief Check if this node is a leaf node.
|
||||||
|
/// \return True if this is a leaf node.
|
||||||
|
const bool IsLeaf() const { return children_.empty(); }
|
||||||
|
|
||||||
|
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
||||||
|
/// \return True if the dataset represented by this node is a mappable dataset
|
||||||
|
const bool IsMappable() const { return mappable_; }
|
||||||
|
|
||||||
|
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||||
|
/// \return True if a cache-enabled operator is an ancestor of this node
|
||||||
|
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
|
||||||
|
|
||||||
|
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||||
|
void HasCacheAbove() { descendant_of_cache_ = true; }
|
||||||
|
|
||||||
/// \brief Setter function for runtime number of workers
|
/// \brief Setter function for runtime number of workers
|
||||||
/// \param[in] num_workers The number of threads in this operator
|
/// \param[in] num_workers The number of threads in this operator
|
||||||
/// \return Shared pointer to the original object
|
/// \return Shared pointer to the original object
|
||||||
|
@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
|
/// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up
|
||||||
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
|
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
|
||||||
/// visit on the way back up the tree after its descendants are visited.
|
/// visit on the way back up the tree after its descendants are visited.
|
||||||
/// \notes Subclass needs to override this if it requires special node visit access.
|
/// \notes Subclass needs to override this if it requires special node visit access.
|
||||||
|
@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
virtual Status Accept(NodePass *p, bool *modified);
|
virtual Status Accept(IRNodePass *p, bool *modified);
|
||||||
|
|
||||||
/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited.
|
/// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited.
|
||||||
/// \notes Subclass needs to override this if it requires special node visit access.
|
/// \notes Subclass needs to override this if it requires special node visit access.
|
||||||
/// Check "dataset/engine/opt/pass.h" for more details.
|
/// Check "dataset/engine/opt/pass.h" for more details.
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
virtual Status AcceptAfter(NodePass *p, bool *modified);
|
virtual Status AcceptAfter(IRNodePass *p, bool *modified);
|
||||||
|
|
||||||
virtual bool IsSizeDefined() { return true; }
|
virtual bool IsSizeDefined() { return true; }
|
||||||
|
|
||||||
|
@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||||
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
||||||
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
||||||
void PrintNode(std::ostream &out, int *level) const;
|
void PrintNode(std::ostream &out, int *level) const;
|
||||||
};
|
|
||||||
|
|
||||||
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
|
|
||||||
class SourceNode : public DatasetNode {
|
|
||||||
public:
|
|
||||||
/// \brief Constructor
|
|
||||||
SourceNode() : DatasetNode() {}
|
|
||||||
|
|
||||||
/// \brief Constructor that initializes the cache
|
|
||||||
/// \param dataset_cache DatasetCache
|
|
||||||
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
|
|
||||||
|
|
||||||
/// \brief Destructor
|
|
||||||
~SourceNode() = default;
|
|
||||||
|
|
||||||
/// \brief Node name getter
|
|
||||||
/// \return Name of the current node
|
|
||||||
virtual std::string Name() const = 0;
|
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
|
||||||
/// \param[in] p The node to visit
|
|
||||||
/// \param[out] modified Indicator if the node was modified
|
|
||||||
/// \return Status of the node visit
|
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
|
||||||
/// \param[in] p The node to visit
|
|
||||||
/// \param[out] modified Indicator if the node was modified
|
|
||||||
/// \return Status of the node visit
|
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
|
||||||
|
|
||||||
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
|
||||||
/// \return True if the dataset represented by this node is a mappable dataset
|
|
||||||
const bool IsMappable() const { return mappable_; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool mappable_;
|
bool mappable_;
|
||||||
|
bool descendant_of_cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
|
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
|
||||||
class MappableSourceNode : public SourceNode {
|
class MappableSourceNode : public DatasetNode {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
MappableSourceNode() : SourceNode() { mappable_ = true; }
|
MappableSourceNode() : DatasetNode() { mappable_ = true; }
|
||||||
|
|
||||||
/// \brief Constructor that initializes the cache
|
/// \brief Constructor that initializes the cache
|
||||||
/// \param dataset_cache DatasetCache
|
/// \param dataset_cache DatasetCache
|
||||||
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
||||||
mappable_ = true;
|
mappable_ = true;
|
||||||
|
// Initially set to false, and set to true by the optimizer when conditions are met.
|
||||||
|
descendant_of_cache_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
|
@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode {
|
||||||
};
|
};
|
||||||
|
|
||||||
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
|
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
|
||||||
class NonMappableSourceNode : public SourceNode {
|
class NonMappableSourceNode : public DatasetNode {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
NonMappableSourceNode() : SourceNode() { mappable_ = false; }
|
NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
|
||||||
|
|
||||||
/// \brief Constructor that initializes the cache
|
/// \brief Constructor that initializes the cache
|
||||||
/// \param dataset_cache DatasetCache
|
/// \param dataset_cache DatasetCache
|
||||||
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
||||||
mappable_ = false;
|
mappable_ = false;
|
||||||
|
// Initially set to false, and set to true by the optimizer when conditions are met.
|
||||||
|
descendant_of_cache_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
|
@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode {
|
||||||
/// \return Name of the current node
|
/// \return Name of the current node
|
||||||
virtual std::string Name() const = 0;
|
virtual std::string Name() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// NonLeafNode represents operations over data in a pipeline.
|
|
||||||
class NonLeafNode : public DatasetNode {
|
|
||||||
public:
|
|
||||||
/// \brief Constructor
|
|
||||||
NonLeafNode() = default;
|
|
||||||
|
|
||||||
/// \brief Destructor
|
|
||||||
~NonLeafNode() = default;
|
|
||||||
|
|
||||||
/// \brief Node name getter
|
|
||||||
/// \return Name of the current node
|
|
||||||
virtual std::string Name() const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// SinkNode represents the end node of a pipeline where the data is pushed out
|
|
||||||
class SinkNode : public DatasetNode {
|
|
||||||
public:
|
|
||||||
/// \brief Constructor
|
|
||||||
SinkNode() = default;
|
|
||||||
|
|
||||||
/// \brief Destructor
|
|
||||||
~SinkNode() = default;
|
|
||||||
|
|
||||||
/// \brief Node name getter
|
|
||||||
/// \return Name of the current node
|
|
||||||
virtual std::string Name() const = 0;
|
|
||||||
};
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
|
||||||
|
|
|
@ -32,8 +32,9 @@ EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epo
|
||||||
// The root node's parent must set to null pointer.
|
// The root node's parent must set to null pointer.
|
||||||
this->AddChild(child);
|
this->AddChild(child);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
|
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
|
||||||
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
|
auto node = std::make_shared<EpochCtrlNode>(num_epochs_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,10 @@ namespace dataset {
|
||||||
class EpochCtrlNode : public DatasetNode {
|
class EpochCtrlNode : public DatasetNode {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {}
|
||||||
|
|
||||||
|
/// \brief Constructor
|
||||||
|
EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~EpochCtrlNode() = default;
|
~EpochCtrlNode() = default;
|
||||||
|
|
|
@ -60,14 +60,14 @@ Status FilterNode::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status FilterNode::Accept(NodePass *p, bool *modified) {
|
Status FilterNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<FilterNode>(), modified);
|
return p->Visit(shared_from_base<FilterNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status FilterNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<FilterNode>(), modified);
|
return p->VisitAfter(shared_from_base<FilterNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,17 +58,17 @@ class FilterNode : public DatasetNode {
|
||||||
|
|
||||||
bool IsSizeDefined() override { return false; };
|
bool IsSizeDefined() override { return false; };
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<TensorOp> predicate_;
|
std::shared_ptr<TensorOp> predicate_;
|
||||||
|
|
|
@ -42,14 +42,16 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> MapNode::Copy() {
|
std::shared_ptr<DatasetNode> MapNode::Copy() {
|
||||||
auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
|
std::vector<std::shared_ptr<TensorOperation>> operations = operations_;
|
||||||
|
auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_,
|
||||||
callbacks_);
|
callbacks_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MapNode::Print(std::ostream &out) const {
|
void MapNode::Print(std::ostream &out) const {
|
||||||
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
|
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
|
||||||
",<project_cols>" + ",...)";
|
",<project_cols>" + ",num_tensor_ops:"
|
||||||
|
<< operations_.size() << ",...)";
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
|
@ -101,14 +103,14 @@ Status MapNode::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status MapNode::Accept(NodePass *p, bool *modified) {
|
Status MapNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<MapNode>(), modified);
|
return p->Visit(shared_from_base<MapNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<MapNode>(), modified);
|
return p->VisitAfter(shared_from_base<MapNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,17 +63,17 @@ class MapNode : public DatasetNode {
|
||||||
const auto &TensorOperations() const { return operations_; }
|
const auto &TensorOperations() const { return operations_; }
|
||||||
auto &TensorOperations() { return operations_; }
|
auto &TensorOperations() { return operations_; }
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||||
|
|
|
@ -70,14 +70,14 @@ Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status RepeatNode::Accept(NodePass *p, bool *modified) {
|
Status RepeatNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<RepeatNode>(), modified);
|
return p->Visit(shared_from_base<RepeatNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status RepeatNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<RepeatNode>(), modified);
|
return p->VisitAfter(shared_from_base<RepeatNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,17 +66,17 @@ class RepeatNode : public DatasetNode {
|
||||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
int64_t *dataset_size) override;
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t repeat_count_;
|
int32_t repeat_count_;
|
||||||
|
|
|
@ -72,14 +72,14 @@ Status RootNode::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status RootNode::Accept(NodePass *p, bool *modified) {
|
Status RootNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<RootNode>(), modified);
|
return p->Visit(shared_from_base<RootNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status RootNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status RootNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<RootNode>(), modified);
|
return p->VisitAfter(shared_from_base<RootNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,17 +58,17 @@ class RootNode : public DatasetNode {
|
||||||
/// \return Status Status::OK() if all the parameters are valid
|
/// \return Status Status::OK() if all the parameters are valid
|
||||||
Status ValidateParams() override;
|
Status ValidateParams() override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t num_epochs_;
|
int32_t num_epochs_;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -70,5 +71,16 @@ Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status SkipNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->Visit(shared_from_base<SkipNode>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status SkipNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->VisitAfter(shared_from_base<SkipNode>(), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -64,6 +64,18 @@ class SkipNode : public DatasetNode {
|
||||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
int64_t *dataset_size) override;
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t skip_count_;
|
int32_t skip_count_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
|
||||||
sampler_(sampler) {}
|
sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> AlbumNode::Copy() {
|
std::shared_ptr<DatasetNode> AlbumNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
|
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
|
||||||
extensions_(extensions) {}
|
extensions_(extensions) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> CelebANode::Copy() {
|
std::shared_ptr<DatasetNode> CelebANode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
|
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
|
||||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
|
std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
|
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
|
||||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
|
std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
|
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -208,7 +208,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(clue_op->Init());
|
RETURN_IF_NOT_OK(clue_op->Init());
|
||||||
|
|
||||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||||
// Inject ShuffleOp
|
// Inject ShuffleOp
|
||||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
int64_t num_rows = 0;
|
int64_t num_rows = 0;
|
||||||
|
|
|
@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation
|
||||||
sampler_(sampler) {}
|
sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> CocoNode::Copy() {
|
std::shared_ptr<DatasetNode> CocoNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
|
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,7 +119,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(csv_op->Init());
|
RETURN_IF_NOT_OK(csv_op->Init());
|
||||||
|
|
||||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||||
// Inject ShuffleOp
|
// Inject ShuffleOp
|
||||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
int64_t num_rows = 0;
|
int64_t num_rows = 0;
|
||||||
|
|
|
@ -33,8 +33,16 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
|
||||||
column_names_(column_names),
|
column_names_(column_names),
|
||||||
column_types_(column_types) {}
|
column_types_(column_types) {}
|
||||||
|
|
||||||
|
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
|
||||||
|
: generator_function_(generator_function), schema_(schema) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
||||||
auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
|
std::shared_ptr<GeneratorNode> node;
|
||||||
|
if (schema_ == nullptr) {
|
||||||
|
node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
|
||||||
|
}
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,9 +50,6 @@ void GeneratorNode::Print(std::ostream &out) const {
|
||||||
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)";
|
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)";
|
||||||
}
|
}
|
||||||
|
|
||||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
|
|
||||||
: generator_function_(generator_function), schema_(schema) {}
|
|
||||||
|
|
||||||
Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
|
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
|
||||||
exts_(extensions) {}
|
exts_(extensions) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
|
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node =
|
auto node =
|
||||||
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
|
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
|
||||||
return node;
|
return node;
|
||||||
|
|
|
@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
|
||||||
sampler_(sampler) {}
|
sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> ManifestNode::Copy() {
|
std::shared_ptr<DatasetNode> ManifestNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
|
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,12 +54,13 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> MindDataNode::Copy() {
|
std::shared_ptr<DatasetNode> MindDataNode::Copy() {
|
||||||
std::shared_ptr<MindDataNode> node;
|
std::shared_ptr<MindDataNode> node;
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
if (dataset_files_.empty()) {
|
if (dataset_files_.empty()) {
|
||||||
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
|
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
|
||||||
} else {
|
} else {
|
||||||
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
|
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
|
||||||
}
|
}
|
||||||
|
node->SetSampleBytes(&sample_bytes_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
|
||||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> MnistNode::Copy() {
|
std::shared_ptr<DatasetNode> MnistNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
|
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||||
|
|
||||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||||
// Inject ShuffleOp
|
// Inject ShuffleOp
|
||||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
int64_t num_rows = 0;
|
int64_t num_rows = 0;
|
||||||
|
|
|
@ -134,7 +134,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(tf_reader_op->Init());
|
RETURN_IF_NOT_OK(tf_reader_op->Init());
|
||||||
|
|
||||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||||
// Inject ShuffleOp
|
// Inject ShuffleOp
|
||||||
|
|
||||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
|
|
|
@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
|
||||||
sampler_(sampler) {}
|
sampler_(sampler) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> VOCNode::Copy() {
|
std::shared_ptr<DatasetNode> VOCNode::Copy() {
|
||||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||||
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
|
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/datasetops/take_op.h"
|
#include "minddata/dataset/engine/datasetops/take_op.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -68,5 +69,16 @@ Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status TakeNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->Visit(shared_from_base<TakeNode>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accepting method for IRNodePass
|
||||||
|
Status TakeNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->VisitAfter(shared_from_base<TakeNode>(), modified);
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -64,6 +64,18 @@ class TakeNode : public DatasetNode {
|
||||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
int64_t *dataset_size) override;
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t take_count_;
|
int32_t take_count_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -104,14 +104,14 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status TransferNode::Accept(NodePass *p, bool *modified) {
|
Status TransferNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<TransferNode>(), modified);
|
return p->Visit(shared_from_base<TransferNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status TransferNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status TransferNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<TransferNode>(), modified);
|
return p->VisitAfter(shared_from_base<TransferNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,17 +58,17 @@ class TransferNode : public DatasetNode {
|
||||||
|
|
||||||
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
|
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
/// \param[in] p The node to visit
|
/// \param[in] p The node to visit
|
||||||
/// \param[out] modified Indicator if the node was modified
|
/// \param[out] modified Indicator if the node was modified
|
||||||
/// \return Status of the node visit
|
/// \return Status of the node visit
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string queue_name_;
|
std::string queue_name_;
|
||||||
|
|
|
@ -79,14 +79,14 @@ Status ZipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status ZipNode::Accept(NodePass *p, bool *modified) {
|
Status ZipNode::Accept(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->Visit(shared_from_base<ZipNode>(), modified);
|
return p->Visit(shared_from_base<ZipNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor accepting method for NodePass
|
// Visitor accepting method for IRNodePass
|
||||||
Status ZipNode::AcceptAfter(NodePass *p, bool *modified) {
|
Status ZipNode::AcceptAfter(IRNodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
return p->VisitAfter(shared_from_base<ZipNode>(), modified);
|
return p->VisitAfter(shared_from_base<ZipNode>(), modified);
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,19 +64,20 @@ class ZipNode : public DatasetNode {
|
||||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
int64_t *dataset_size) override;
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status AcceptAfter(IRNodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::shared_ptr<DatasetNode>> datasets_;
|
std::vector<std::shared_ptr<DatasetNode>> datasets_;
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
|
||||||
/// \param[in] p The node to visit
|
|
||||||
/// \param[out] modified Indicator if the node was modified
|
|
||||||
/// \return Status of the node visit
|
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
|
||||||
|
|
||||||
/// \brief Base-class override for accepting NodePass visitor
|
|
||||||
/// \param[in] p The node to visit
|
|
||||||
/// \param[out] modified Indicator if the node was modified
|
|
||||||
/// \return Status of the node visit
|
|
||||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -6,9 +6,12 @@ add_library(engine-opt OBJECT
|
||||||
post/repeat_pass.cc
|
post/repeat_pass.cc
|
||||||
pre/cache_error_pass.cc
|
pre/cache_error_pass.cc
|
||||||
pre/cache_transform_pass.cc
|
pre/cache_transform_pass.cc
|
||||||
|
pre/cache_validation_pass.cc
|
||||||
|
pre/epoch_ctrl_pass.cc
|
||||||
pre/epoch_injection_pass.cc
|
pre/epoch_injection_pass.cc
|
||||||
pre/getter_pass.cc
|
pre/getter_pass.cc
|
||||||
pre/input_validation_pass.cc
|
pre/input_validation_pass.cc
|
||||||
|
pre/node_removal_pass.cc
|
||||||
pre/removal_pass.cc
|
pre/removal_pass.cc
|
||||||
util/printer_pass.cc
|
util/printer_pass.cc
|
||||||
)
|
)
|
||||||
|
|
|
@ -87,7 +87,7 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Driver method for TreePass
|
// Driver method for TreePass
|
||||||
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
if (root_ir == nullptr || modified == nullptr) {
|
if (root_ir == nullptr || modified == nullptr) {
|
||||||
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
|
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver method for NodePass
|
// Driver method for NodePass
|
||||||
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
if (root_ir == nullptr || modified == nullptr) {
|
if (root_ir == nullptr || modified == nullptr) {
|
||||||
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
|
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
|
||||||
}
|
}
|
||||||
|
@ -110,7 +110,7 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to perform DFS visit
|
// Helper function to perform DFS visit
|
||||||
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
||||||
bool m = false;
|
bool m = false;
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
|
RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
|
||||||
|
@ -125,7 +125,7 @@ Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to perform BFS visit
|
// Helper function to perform BFS visit
|
||||||
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
||||||
bool m = false;
|
bool m = false;
|
||||||
|
|
||||||
// Initialize bfs queue with root
|
// Initialize bfs queue with root
|
||||||
|
@ -151,121 +151,113 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
|
||||||
}
|
}
|
||||||
|
|
||||||
// For non-leaf IR node
|
// For non-leaf IR node
|
||||||
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
#ifdef ENABLE_PYTHON
|
#ifdef ENABLE_PYTHON
|
||||||
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// For leaf IR Node
|
|
||||||
Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) {
|
|
||||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
|
||||||
}
|
|
||||||
Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) {
|
|
||||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////
|
//////////////////////////////////
|
||||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||||
// Driver method for TreePass
|
// Driver method for TreePass
|
||||||
|
|
|
@ -113,26 +113,18 @@ class GeneratorOp;
|
||||||
|
|
||||||
// The base class Pass is the basic unit of tree transformation.
|
// The base class Pass is the basic unit of tree transformation.
|
||||||
// The actual implementation of the passes will be derived from here.
|
// The actual implementation of the passes will be derived from here.
|
||||||
class Pass : public std::enable_shared_from_this<Pass> {
|
class IRPass : public std::enable_shared_from_this<IRPass> {
|
||||||
public:
|
public:
|
||||||
// Run the transformation pass against the IR tree.
|
// Run the transformation pass against the IR tree.
|
||||||
// @param root_ir - Pointer to the IR tree to be transformed.
|
// @param root_ir - Pointer to the IR tree to be transformed.
|
||||||
// @param modified - Pointer to the modified flag,
|
// @param modified - Pointer to the modified flag,
|
||||||
virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0;
|
virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0;
|
||||||
|
|
||||||
//////////////////////////////////
|
virtual ~IRPass() = default;
|
||||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
|
||||||
// Run the transformation pass against the execution tree.
|
|
||||||
// @param tree - Pointer to the execution tree to be transformed.
|
|
||||||
// @param modified - Pointer to the modified flag,
|
|
||||||
virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
virtual ~Pass() = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
|
// IRTreePass is a basic Pass class which performs transformation on IR tree directly.
|
||||||
class TreePass : public Pass {
|
class IRTreePass : public IRPass {
|
||||||
public:
|
public:
|
||||||
/// \brief Run the transformation pass against the IR tree.
|
/// \brief Run the transformation pass against the IR tree.
|
||||||
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
|
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
|
||||||
|
@ -145,44 +137,29 @@ class TreePass : public Pass {
|
||||||
/// \param[inout] Indicate if the tree was modified.
|
/// \param[inout] Indicate if the tree was modified.
|
||||||
/// \return Status The error code return
|
/// \return Status The error code return
|
||||||
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
|
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
|
||||||
/// \brief Run the transformation pass against the execution tree.
|
|
||||||
/// \param[inout] tree Pointer to the execution tree to be transformed.
|
|
||||||
/// \param[inout] modified Indicate if the tree was modified
|
|
||||||
Status Run(ExecutionTree *tree, bool *modified) final;
|
|
||||||
|
|
||||||
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
|
||||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
|
||||||
/// \param[inout] tree The tree to operate on.
|
|
||||||
/// \param[inout] Indicate of the tree was modified.
|
|
||||||
/// \return Status The error code return
|
|
||||||
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
|
||||||
//////////////////////////////////
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// NodePass is a base Pass class which performs transformation on node visiting.
|
// IRNodePass is a base Pass class which performs transformation on node visiting.
|
||||||
// NodePass implements Visitor design pattern.
|
// IRNodePass implements Visitor design pattern.
|
||||||
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
|
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
|
||||||
// and the other when all the descending nodes are visited.
|
// and the other when all the descending nodes are visited.
|
||||||
// Actual transformation is done by implementing a new derived class of NodePass.
|
// Actual transformation is done by implementing a new derived class of IRNodePass.
|
||||||
// The derived class will implement the method Visit()/VisitAfter() passing specified node types
|
// The derived class will implement the method Visit()/VisitAfter() passing specified node types
|
||||||
// it wants to action on them, overriding the ones defined in NodePass.
|
// it wants to action on them, overriding the ones defined in IRNodePass.
|
||||||
// If the derived class wants to perform the same action on all node types,
|
// If the derived class wants to perform the same action on all node types,
|
||||||
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
|
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
|
||||||
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
|
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
|
||||||
// to call the Visit()/VisitAfter() in this parent NodePass class.
|
// to call the Visit()/VisitAfter() in this parent IRNodePass class.
|
||||||
class NodePass : public Pass {
|
class IRNodePass : public IRPass {
|
||||||
public:
|
public:
|
||||||
// Tree traversal order
|
// Tree traversal order
|
||||||
enum Order { DFS, BFS };
|
enum Order { DFS, BFS };
|
||||||
|
|
||||||
// Constructor
|
// Constructor
|
||||||
// Default DFS traversal
|
// Default DFS traversal
|
||||||
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
|
explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; }
|
||||||
|
|
||||||
~NodePass() = default;
|
~IRNodePass() = default;
|
||||||
|
|
||||||
/// \brief Run the transformation pass against the IR tree
|
/// \brief Run the transformation pass against the IR tree
|
||||||
/// \param[inout] root_ir Pointer to the IR tree to be transformed
|
/// \param[inout] root_ir Pointer to the IR tree to be transformed
|
||||||
|
@ -251,12 +228,70 @@ class NodePass : public Pass {
|
||||||
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||||
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||||
#endif
|
#endif
|
||||||
// Leaf IR node
|
|
||||||
virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified);
|
|
||||||
virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
private:
|
||||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
// Helper function to perform DFS visit
|
||||||
|
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
|
||||||
|
|
||||||
|
// Helper function to perform BFS visit
|
||||||
|
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
|
||||||
|
|
||||||
|
// Tree traversal order of the NodePass
|
||||||
|
Order traversalOrder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
//////////////////////////////////
|
||||||
|
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||||
|
// The base class Pass is the basic unit of tree transformation.
|
||||||
|
// The actual implementation of the passes will be derived from here.
|
||||||
|
class Pass : public std::enable_shared_from_this<Pass> {
|
||||||
|
public:
|
||||||
|
// Run the transformation pass against the execution tree.
|
||||||
|
// @param tree - Pointer to the execution tree to be transformed.
|
||||||
|
// @param modified - Pointer to the modified flag,
|
||||||
|
virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
|
||||||
|
|
||||||
|
virtual ~Pass() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
|
||||||
|
class TreePass : public Pass {
|
||||||
|
public:
|
||||||
|
/// \brief Run the transformation pass against the execution tree.
|
||||||
|
/// \param[inout] tree Pointer to the execution tree to be transformed.
|
||||||
|
/// \param[inout] modified Indicate if the tree was modified
|
||||||
|
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||||
|
|
||||||
|
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
||||||
|
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||||
|
/// \param[inout] tree The tree to operate on.
|
||||||
|
/// \param[inout] Indicate of the tree was modified.
|
||||||
|
/// \return Status The error code return
|
||||||
|
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// NodePass is a base Pass class which performs transformation on node visiting.
|
||||||
|
// NodePass implements Visitor design pattern.
|
||||||
|
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
|
||||||
|
// and the other when all the descending nodes are visited.
|
||||||
|
// Actual transformation is done by implementing a new derived class of NodePass.
|
||||||
|
// The derived class will implement the method Visit()/VisitAfter() passing specified node types
|
||||||
|
// it wants to action on them, overriding the ones defined in NodePass.
|
||||||
|
// If the derived class wants to perform the same action on all node types,
|
||||||
|
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
|
||||||
|
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
|
||||||
|
// to call the Visit()/VisitAfter() in this parent NodePass class.
|
||||||
|
class NodePass : public Pass {
|
||||||
|
public:
|
||||||
|
// Tree traversal order
|
||||||
|
enum Order { DFS, BFS };
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
// Default DFS traversal
|
||||||
|
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
|
||||||
|
|
||||||
|
~NodePass() = default;
|
||||||
|
|
||||||
/// \brief Run the transformation pass against the execution tree
|
/// \brief Run the transformation pass against the execution tree
|
||||||
/// \param[inout] tree Pointer to the execution tree to be transformed
|
/// \param[inout] tree Pointer to the execution tree to be transformed
|
||||||
/// \param[inout] modified Indicator if the tree was changed
|
/// \param[inout] modified Indicator if the tree was changed
|
||||||
|
@ -326,27 +361,18 @@ class NodePass : public Pass {
|
||||||
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
||||||
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||||
#endif
|
#endif
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Helper function to perform DFS visit
|
|
||||||
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
|
|
||||||
|
|
||||||
// Helper function to perform BFS visit
|
|
||||||
Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
|
||||||
// Helper function to perform DFS visit
|
// Helper function to perform DFS visit
|
||||||
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||||
|
|
||||||
// Helper function to perform BFS visit
|
// Helper function to perform BFS visit
|
||||||
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
|
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
// Tree traversal order of the NodePass
|
// Tree traversal order of the NodePass
|
||||||
Order traversalOrder_;
|
Order traversalOrder_;
|
||||||
};
|
};
|
||||||
|
//////////////////////////////////
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,163 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||||
|
#include "minddata/dataset/include/transforms.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {}
|
||||||
|
|
||||||
|
// Returns an error if BatchNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<BatchNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if ConcatNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ConcatNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if FilterNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<FilterNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if SkipNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<SkipNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if TakeNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<TakeNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if ZipNode exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ZipNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache.");
|
||||||
|
}
|
||||||
|
if (node->IsCached()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if MapNode with non-deterministic tensor operations exists under a cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<MapNode>): visiting " << node->Name() << ".";
|
||||||
|
if (node->IsCached()) {
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported.");
|
||||||
|
}
|
||||||
|
// If Map is created to be cached, set the flag indicating we found an operation with a cache.
|
||||||
|
is_cached_ = true;
|
||||||
|
auto tfuncs = node->TensorOperations();
|
||||||
|
for (size_t i = 0; i < tfuncs.size(); i++) {
|
||||||
|
if (tfuncs[i]->IsRandomOp()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED(
|
||||||
|
"MapNode with non-deterministic operations is not supported as a descendant of cache.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flag an error if we have a cache over another cache
|
||||||
|
Status CacheValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::Visit(<DatasetNode>): visiting " << node->Name() << ".";
|
||||||
|
if (node->IsCached()) {
|
||||||
|
if (is_cached_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported.");
|
||||||
|
}
|
||||||
|
// If this node is created to be cached, set the flag.
|
||||||
|
is_cached_ = true;
|
||||||
|
}
|
||||||
|
if (node->IsLeaf() && node->IsMappable()) {
|
||||||
|
is_mappable_ = true;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an error if MappableSource <- Repeat <- Node with a cache
|
||||||
|
// Because there is no operator in the cache hit stream to consume EoEs, caching above repeat causes problem.
|
||||||
|
Status CacheValidationPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<RepeatNode>): visiting " << node->Name() << ".";
|
||||||
|
if (is_cached_ && is_mappable_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CacheValidationPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
|
||||||
|
MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<DatasetNode>): visiting " << node->Name() << ".";
|
||||||
|
// Reset the flag when all descendants are visited
|
||||||
|
if (node->IsCached()) {
|
||||||
|
is_cached_ = false;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,105 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <stack>
|
||||||
|
#include <utility>
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
/// \class CacheValidationPass cache_validation_pass.h
|
||||||
|
/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures.
|
||||||
|
class CacheValidationPass : public IRNodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
CacheValidationPass();
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~CacheValidationPass() = default;
|
||||||
|
|
||||||
|
/// \brief Returns an error if BatchNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if ConcatNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if FilterNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if SkipNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if TakeNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if ZipNode exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Returns an error if there is a cache over another cache
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Identifies and block repeat under cache scenarios
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree above this node as not being cached
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_cached_;
|
||||||
|
bool is_mappable_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_
|
|
@ -0,0 +1,85 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// constructor
|
||||||
|
EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetNode> node)
|
||||||
|
: injection_point_(nullptr), num_epochs_(-1) {}
|
||||||
|
|
||||||
|
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
|
||||||
|
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *modified) {
|
||||||
|
// The injection is at the child of the root node
|
||||||
|
injection_point_ = node;
|
||||||
|
num_epochs_ = node->num_epochs();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
|
||||||
|
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
||||||
|
injection_point_ = nullptr;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
// Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection
|
||||||
|
Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||||
|
injection_point_ = nullptr;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||||
|
// Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
|
||||||
|
// Move the injection point to the child of this node.
|
||||||
|
injection_point_ = node;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructor
|
||||||
|
EpochCtrlPass::EpochCtrlPass() {}
|
||||||
|
|
||||||
|
// Runs an injection pass to inject in operators needed at the pre pass stage
|
||||||
|
Status EpochCtrlPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
|
MS_LOG(INFO) << "Pre pass: Injection pass started.";
|
||||||
|
|
||||||
|
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
|
||||||
|
// The finder can make updates to the EpochInjectionPass object.
|
||||||
|
EpochCtrlPass::InjectionFinder finder(root_ir);
|
||||||
|
RETURN_IF_NOT_OK(finder.Run(root_ir, modified));
|
||||||
|
|
||||||
|
// The first injection logic is to check if we should inject the epoch control op as the root node.
|
||||||
|
// Do not inject the op if the number of epochs is 1.
|
||||||
|
std::shared_ptr<DatasetNode> parent = finder.injection_point();
|
||||||
|
int32_t num_epochs = finder.num_epochs();
|
||||||
|
if (num_epochs != 1 && parent != nullptr) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(parent->Children().size() == 1, "EpochCtrl must be injected on only one child.");
|
||||||
|
auto epoch_ctrl_node = std::make_shared<EpochCtrlNode>(num_epochs);
|
||||||
|
RETURN_IF_NOT_OK(parent->InsertBelow(epoch_ctrl_node));
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,98 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class DatasetOp;
|
||||||
|
|
||||||
|
/// \class EpochInjectionPass epoch_ctrl_pass.h
|
||||||
|
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
|
||||||
|
/// parsing.
|
||||||
|
class EpochCtrlPass : public IRTreePass {
|
||||||
|
/// \class InjectionFinder
|
||||||
|
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
|
||||||
|
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
|
||||||
|
/// it may need to inject.
|
||||||
|
class InjectionFinder : public IRNodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
explicit InjectionFinder(std::shared_ptr<DatasetNode> node);
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~InjectionFinder() = default;
|
||||||
|
|
||||||
|
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<RootNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/// \brief Register the TransferNode for further action.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Getter
|
||||||
|
std::shared_ptr<DatasetNode> injection_point() { return injection_point_; }
|
||||||
|
|
||||||
|
/// \brief Getter
|
||||||
|
int32_t num_epochs() { return num_epochs_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<DatasetNode> injection_point_;
|
||||||
|
int32_t num_epochs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
EpochCtrlPass();
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~EpochCtrlPass() = default;
|
||||||
|
|
||||||
|
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||||
|
/// \param[inout] tree The tree to operate on.
|
||||||
|
/// \param[inout] Indicate of the tree was modified.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
|
@ -26,7 +26,7 @@ namespace dataset {
|
||||||
|
|
||||||
/// \class InputValidationPass
|
/// \class InputValidationPass
|
||||||
/// \brief This is a parse pass that validates input parameters of the IR tree.
|
/// \brief This is a parse pass that validates input parameters of the IR tree.
|
||||||
class InputValidationPass : public NodePass {
|
class InputValidationPass : public IRNodePass {
|
||||||
/// \brief Runs a validatation pass to check input parameters
|
/// \brief Runs a validatation pass to check input parameters
|
||||||
/// \param[in] node The node being visited
|
/// \param[in] node The node being visited
|
||||||
/// \param[inout] *modified indicates whether the node has been visited
|
/// \param[inout] *modified indicates whether the node has been visited
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
NodeRemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
|
||||||
|
|
||||||
|
// Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
MS_LOG(INFO) << "Node removal pass: Operation with cache found, identified descendant tree.";
|
||||||
|
if (node->IsCached()) {
|
||||||
|
is_caching_ = true;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets the tracking of the cache within the tree
|
||||||
|
Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
MS_LOG(INFO) << "Removal pass: Descendant walk is complete.";
|
||||||
|
if (is_caching_ && node->IsLeaf()) {
|
||||||
|
// Mark this leaf node to indicate it is a descendant of an operator with cache.
|
||||||
|
// This is currently used in non-mappable dataset (leaf) nodes to not add a ShuffleOp in DatasetNode::Build().
|
||||||
|
node->HasCacheAbove();
|
||||||
|
}
|
||||||
|
is_caching_ = false;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform ShuffleOp removal check.
|
||||||
|
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
#if 0
|
||||||
|
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||||
|
if (is_caching_) {
|
||||||
|
MS_LOG(INFO) << "Shuffle under an operation with cache is identified for removal.";
|
||||||
|
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructor
|
||||||
|
NodeRemovalPass::NodeRemovalPass() {}
|
||||||
|
|
||||||
|
// Walk the tree to collect the nodes to remove, then removes them.
|
||||||
|
Status NodeRemovalPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||||
|
MS_LOG(INFO) << "Pre pass: node removal pass started.";
|
||||||
|
// Create the removal node pass which can identify which nodes need to be removed.
|
||||||
|
std::unique_ptr<NodeRemovalPass::RemovalNodes> removal_nodes = std::make_unique<NodeRemovalPass::RemovalNodes>();
|
||||||
|
RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified));
|
||||||
|
|
||||||
|
// Then, execute the removal of any nodes that were set up for removal
|
||||||
|
for (auto node : removal_nodes->nodes_to_remove()) {
|
||||||
|
RETURN_IF_NOT_OK(node->Remove());
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Pre pass: node removal pass complete.";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,88 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class DatasetOp;
|
||||||
|
|
||||||
|
/// \class RemovalPass removal_pass.h
|
||||||
|
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||||
|
/// nodes should be removed, and then removes them.
|
||||||
|
class NodeRemovalPass : public IRTreePass {
|
||||||
|
/// \class RemovalNodes
|
||||||
|
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||||
|
/// It works in conjunction with the removal_pass.
|
||||||
|
class RemovalNodes : public IRNodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||||
|
RemovalNodes();
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~RemovalNodes() = default;
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Resets the tracking of the cache within the tree
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform ShuffleNode removal check
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Getter
|
||||||
|
/// \return All the nodes to be removed
|
||||||
|
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_caching_;
|
||||||
|
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
NodeRemovalPass();
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~NodeRemovalPass() = default;
|
||||||
|
|
||||||
|
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||||
|
/// \param[inout] tree The tree to operate on.
|
||||||
|
/// \param[inout] Indicate of the tree was modified.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
|
|
@ -17,34 +17,25 @@
|
||||||
#include "minddata/dataset/engine/tree_adapter.h"
|
#include "minddata/dataset/engine/tree_adapter.h"
|
||||||
|
|
||||||
#include "minddata/dataset/core/client.h"
|
#include "minddata/dataset/core/client.h"
|
||||||
#include "minddata/dataset/include/datasets.h"
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||||
#include "minddata/dataset/engine/opt/pass.h"
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
|
||||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||||
|
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
||||||
// Vector of actions in validation pass
|
// Vector of actions in pre-pass phase
|
||||||
std::vector<std::unique_ptr<NodePass>> validations;
|
std::vector<std::unique_ptr<IRPass>> actions;
|
||||||
|
|
||||||
MS_LOG(INFO) << "Running pre pass loops.";
|
MS_LOG(INFO) << "Running pre pass loops.";
|
||||||
validations.push_back(std::make_unique<InputValidationPass>());
|
actions.push_back(std::make_unique<InputValidationPass>());
|
||||||
|
actions.push_back(std::make_unique<CacheValidationPass>());
|
||||||
// Vector of flags for each action
|
actions.push_back(std::make_unique<NodeRemovalPass>());
|
||||||
// Apply validation actions
|
actions.push_back(std::make_unique<EpochCtrlPass>());
|
||||||
for (auto i = 0; i < validations.size(); i++) {
|
|
||||||
auto modified = false;
|
|
||||||
// InputValidationPass does not change the IR tree. We don't need to capture the "modified" value.
|
|
||||||
RETURN_IF_NOT_OK(validations[i]->Run(ir, &modified));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Vector of actions in pre-pass phase
|
|
||||||
std::vector<std::unique_ptr<Pass>> actions;
|
|
||||||
|
|
||||||
// We will gradually move CacheErrorPass, EpochInjectionPass, CacheTransformPass
|
|
||||||
// from ExecutionTree::PrepareTreePreAction to here.
|
|
||||||
|
|
||||||
// Vector of flags for each action
|
// Vector of flags for each action
|
||||||
std::vector<bool> modified(actions.size(), false);
|
std::vector<bool> modified(actions.size(), false);
|
||||||
|
@ -60,7 +51,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
||||||
|
|
||||||
Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
|
Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
|
||||||
// Vector of optimizations
|
// Vector of optimizations
|
||||||
std::vector<std::unique_ptr<NodePass>> optimizations;
|
std::vector<std::unique_ptr<IRNodePass>> optimizations;
|
||||||
MS_LOG(INFO) << "Running optimization pass loops";
|
MS_LOG(INFO) << "Running optimization pass loops";
|
||||||
|
|
||||||
// We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here.
|
// We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here.
|
||||||
|
@ -79,7 +70,7 @@ Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
|
||||||
|
|
||||||
Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
||||||
// Vector of actions in post-pass phase
|
// Vector of actions in post-pass phase
|
||||||
std::vector<std::unique_ptr<Pass>> actions;
|
std::vector<std::unique_ptr<IRPass>> actions;
|
||||||
MS_LOG(INFO) << "Running post pass loops.";
|
MS_LOG(INFO) << "Running post pass loops.";
|
||||||
|
|
||||||
// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.
|
// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.
|
||||||
|
@ -96,10 +87,6 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
||||||
// Check if pipeline is valid or not
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(ir->Parent().size() <= 1,
|
|
||||||
"The data pipeline is not a tree (i.e. one node has two consumers)");
|
|
||||||
|
|
||||||
// Build the DatasetOp ExecutionTree from the optimized IR tree
|
// Build the DatasetOp ExecutionTree from the optimized IR tree
|
||||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||||
RETURN_IF_NOT_OK(ir->Build(&ops));
|
RETURN_IF_NOT_OK(ir->Build(&ops));
|
||||||
|
@ -130,8 +117,12 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
||||||
RETURN_UNEXPECTED_IF_NULL(input_ir);
|
RETURN_UNEXPECTED_IF_NULL(input_ir);
|
||||||
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
|
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
|
||||||
|
|
||||||
|
// We will first walk the input tree to sanity check this is not a graph
|
||||||
|
// Flag an error when it is not a tree
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)");
|
||||||
|
|
||||||
// Copy the input IR tree and insert under the root node
|
// Copy the input IR tree and insert under the root node
|
||||||
// Create a root node to host the input IR tree, the deepcopied tree will be passed to optimization pass
|
// Create a root node to host the new copy of the input IR tree to pass to the optimizer
|
||||||
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
|
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
|
||||||
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
|
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
|
||||||
|
|
||||||
|
@ -151,11 +142,9 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
||||||
// This will evolve in the long run
|
// This will evolve in the long run
|
||||||
tree_ = std::make_unique<ExecutionTree>();
|
tree_ = std::make_unique<ExecutionTree>();
|
||||||
|
|
||||||
// Build the Execution tree from the child of the root node
|
// Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree
|
||||||
std::shared_ptr<DatasetOp> root_op;
|
std::shared_ptr<DatasetOp> root_op;
|
||||||
// input_ir is the ir node before the deepcopy.
|
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op));
|
||||||
// We will replace input_ir with root_ir->Children()[0] once IR optimizer is in
|
|
||||||
RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op));
|
|
||||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||||
|
|
||||||
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
||||||
|
@ -163,7 +152,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
||||||
// Note: We will gradually move the pre pass, optimizer pass, and post pass
|
// Note: We will gradually move the pre pass, optimizer pass, and post pass
|
||||||
// on ExecutionTree to perform on IR tree.
|
// on ExecutionTree to perform on IR tree.
|
||||||
// Prepare the tree
|
// Prepare the tree
|
||||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
|
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true));
|
||||||
|
|
||||||
// After the tree is prepared, the col_name_id_map can safely be obtained
|
// After the tree is prepared, the col_name_id_map can safely be obtained
|
||||||
column_name_map_ = tree_->root()->column_name_id_map();
|
column_name_map_ = tree_->root()->column_name_id_map();
|
||||||
|
|
|
@ -44,7 +44,7 @@ Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
|
||||||
}
|
}
|
||||||
|
|
||||||
void NormalizeOp::Print(std::ostream &out) const {
|
void NormalizeOp::Print(std::ostream &out) const {
|
||||||
out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl;
|
out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl;
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) {
|
||||||
};
|
};
|
||||||
|
|
||||||
exe_tree->SetPrePassOverride(pass);
|
exe_tree->SetPrePassOverride(pass);
|
||||||
ASSERT_OK(exe_tree->PrepareTreePreAction());
|
ASSERT_OK(exe_tree->PreAction());
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
|
||||||
// print the tree in std::string as a way to verify that nodes are indeed removed
|
// print the tree in std::string as a way to verify that nodes are indeed removed
|
||||||
|
@ -124,7 +124,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
|
||||||
};
|
};
|
||||||
|
|
||||||
exe_tree->SetPrePassOverride(pass);
|
exe_tree->SetPrePassOverride(pass);
|
||||||
ASSERT_OK(exe_tree->PrepareTreePreAction());
|
ASSERT_OK(exe_tree->PreAction());
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
// print the tree in std::string as a way to verify that nodes are indeed removed
|
// print the tree in std::string as a way to verify that nodes are indeed removed
|
||||||
exe_tree->Print(ss);
|
exe_tree->Print(ss);
|
||||||
|
|
|
@ -237,7 +237,7 @@ def test_cache_map_failure1():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator(num_epochs=1):
|
for _ in ds1.create_dict_iterator(num_epochs=1):
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "Nested cache operations is not supported!" in str(e.value)
|
assert "Nested cache operations" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure1 Ended.\n')
|
logger.info('test_cache_failure1 Ended.\n')
|
||||||
|
@ -279,7 +279,7 @@ def test_cache_map_failure2():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in dsz.create_dict_iterator():
|
for _ in dsz.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure2 Ended.\n')
|
logger.info('test_cache_failure2 Ended.\n')
|
||||||
|
@ -319,7 +319,7 @@ def test_cache_map_failure3():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator():
|
for _ in ds1.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure3 Ended.\n')
|
logger.info('test_cache_failure3 Ended.\n')
|
||||||
|
@ -361,7 +361,7 @@ def test_cache_map_failure4():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator():
|
for _ in ds1.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure4 Ended.\n')
|
logger.info('test_cache_failure4 Ended.\n')
|
||||||
|
@ -402,7 +402,7 @@ def test_cache_map_failure5():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in data.create_dict_iterator():
|
for _ in data.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value)
|
assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure5 Ended.\n')
|
logger.info('test_cache_failure5 Ended.\n')
|
||||||
|
@ -522,7 +522,7 @@ def test_cache_map_failure8():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator(num_epochs=1):
|
for _ in ds1.create_dict_iterator(num_epochs=1):
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value)
|
assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure8 Ended.\n')
|
logger.info('test_cache_failure8 Ended.\n')
|
||||||
|
@ -564,7 +564,7 @@ def test_cache_map_failure9():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator():
|
for _ in ds1.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure9 Ended.\n')
|
logger.info('test_cache_failure9 Ended.\n')
|
||||||
|
@ -606,7 +606,7 @@ def test_cache_map_failure10():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator():
|
for _ in ds1.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "SkipOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
logger.info('test_cache_failure10 Ended.\n')
|
logger.info('test_cache_failure10 Ended.\n')
|
||||||
|
@ -655,13 +655,13 @@ def test_cache_map_split1():
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds1.create_dict_iterator():
|
for _ in ds1.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as e:
|
with pytest.raises(RuntimeError) as e:
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in ds2.create_dict_iterator():
|
for _ in ds2.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
|
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
|
||||||
logger.info('test_cache_split1 Ended.\n')
|
logger.info('test_cache_split1 Ended.\n')
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue