forked from OSSInnovation/mindspore
!2772 add a pre pass for node removals
Merge pull request !2772 from Jamie/removalpass
This commit is contained in:
commit
efe07bd169
|
@ -409,7 +409,7 @@ Status BatchOp::UnpackPadInfo(const PadInfo &pad_info,
|
|||
// Visitor accept method for NodePass
|
||||
Status BatchOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<BatchOp>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -111,6 +111,51 @@ void DatasetOp::RemoveParent(const DatasetOp *parent) {
|
|||
parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
|
||||
}
|
||||
|
||||
// Removes this node from the tree and connects it's parent/child together
|
||||
Status DatasetOp::Remove() {
|
||||
if (parent_.size() > 1) {
|
||||
std::string err_msg("No support for op removal if the operator has more than one parent");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
if (child_.size() > 1) {
|
||||
std::string err_msg("No support for op removal if the operator has more than one child");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// Scenario's when removing node B:
|
||||
// A -> B -> C
|
||||
// A -> B
|
||||
// B -> C
|
||||
//
|
||||
// If we remove B, then first take our child A and update it's parent to be C
|
||||
// It's possible the parent is null if we are the root node being removed.
|
||||
if (!child_.empty()) {
|
||||
// If we have a parent, then assign chlid's parent to point to our parent.
|
||||
if (!parent_.empty()) {
|
||||
child_[0]->parent_[0] = parent_[0];
|
||||
} else {
|
||||
// We don't have a parent, so we are the root node being removed.
|
||||
// clear the parent list of our child so that it becomes the new root.
|
||||
child_[0]->parent_.clear();
|
||||
tree_->AssignRoot(child_[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Next, if we had a parent, then set it's child to be our child.
|
||||
if (!parent_.empty()) {
|
||||
// if we have a child, then set our parent to point to it
|
||||
if (!child_.empty()) {
|
||||
parent_[0]->child_[0] = child_[0];
|
||||
} else {
|
||||
// We don't have a child, so clear the child list of the current
|
||||
// parent because it will be empty once we are removed.
|
||||
parent_[0]->child_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Getter function to get a shared pointer to our childAdds a operator to become our child.
|
||||
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
|
||||
MS_ASSERT(child_index < static_cast<int>(child_.size()));
|
||||
|
@ -289,6 +334,12 @@ Status DatasetOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// DatasetOp is the base class of visitor target pre-visit.
|
||||
// This method will only be called if its derived class does not implement one.
|
||||
return p->PreRunOnNode(shared_from_this(), modified);
|
||||
}
|
||||
|
||||
Status DatasetOp::Accept(NodePass *p, bool *modified) {
|
||||
// DatasetOp is the base class of visitor target.
|
||||
// This method will only be called if its derived class does not implement one.
|
||||
|
|
|
@ -71,6 +71,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// @param child - shared pointer to the child to remove.
|
||||
Status RemoveChild(std::shared_ptr<DatasetOp> child);
|
||||
|
||||
/// \brief Removes this node from the tree and connects it's parent/child together.
|
||||
/// \return Status eerror code returned
|
||||
Status Remove();
|
||||
|
||||
// Getter function to get a shared pointer to our child
|
||||
// @param child_index - An operator can have n children. Indicates choose which child to return.
|
||||
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
|
||||
|
@ -264,10 +268,20 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// @return Vector of Children
|
||||
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
|
||||
|
||||
// Base method for NodePass visit.
|
||||
// Subclass needs to override this if it requires special node visit access.
|
||||
// Check "dataset/engine/opt/pass.h" for more details.
|
||||
// @return Statue of the node visit
|
||||
/// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up
|
||||
/// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main
|
||||
/// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it
|
||||
/// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details.
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
virtual Status PreAccept(NodePass *p, bool *modified);
|
||||
|
||||
/// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access.
|
||||
/// Check "dataset/engine/opt/pass.h" for more details.
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
virtual Status Accept(NodePass *p, bool *modified);
|
||||
|
||||
// Op name getter
|
||||
|
@ -285,6 +299,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// Computes a CRC value for the operator
|
||||
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
|
||||
|
||||
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
|
||||
/// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
|
||||
/// \return A shared_ptr casted to the derived class
|
||||
template <typename Derived>
|
||||
std::shared_ptr<Derived> shared_from_base() {
|
||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||
}
|
||||
|
||||
protected:
|
||||
// Adds a parent operator to this operator
|
||||
// @notes External callers do not have access to this function.
|
||||
|
|
|
@ -313,7 +313,7 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Visitor accept method for NodePass
|
||||
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<DeviceQueueOp>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -261,7 +261,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
|
|||
// Visitor accept method for NodePass
|
||||
Status FilterOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<FilterOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -367,7 +367,7 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
|
|||
// Visitor accept method for NodePass
|
||||
Status MapOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<MapOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -131,7 +131,7 @@ Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
|
|||
// Visitor accept method for NodePass
|
||||
Status ProjectOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<ProjectOp>(), modified);
|
||||
}
|
||||
|
||||
// Compute the column map and save it into our own column name map
|
||||
|
|
|
@ -176,7 +176,7 @@ Status RenameOp::EoeReceived(int32_t) {
|
|||
// Visitor accept method for NodePass
|
||||
Status RenameOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<RenameOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -190,7 +190,7 @@ int32_t RepeatOp::num_producers() const {
|
|||
// Visitor accept method for NodePass
|
||||
Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -298,7 +298,7 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
|
|||
// Visitor accept method for NodePass
|
||||
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<ShuffleOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -130,7 +130,7 @@ Status SkipOp::EofReceived(int32_t worker_id) {
|
|||
// Visitor accept method for NodePass
|
||||
Status SkipOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<SkipOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -249,7 +249,7 @@ Status GeneratorOp::Reset() {
|
|||
// Visitor accept method for NodePass
|
||||
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<GeneratorOp>(), modified);
|
||||
}
|
||||
|
||||
Status GeneratorOp::ComputeColMap() {
|
||||
|
|
|
@ -411,7 +411,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se
|
|||
// Visitor accept method for NodePass
|
||||
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified);
|
||||
}
|
||||
|
||||
Status ImageFolderOp::ComputeColMap() {
|
||||
|
|
|
@ -496,7 +496,7 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
|
|||
// Visitor accept method for NodePass
|
||||
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<MindRecordOp>(), modified);
|
||||
}
|
||||
|
||||
Status MindRecordOp::ComputeColMap() {
|
||||
|
|
|
@ -1004,7 +1004,7 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
|
|||
// Visitor accept method for NodePass
|
||||
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<TFReaderOp>(), modified);
|
||||
}
|
||||
|
||||
Status TFReaderOp::ComputeColMap() {
|
||||
|
|
|
@ -136,7 +136,7 @@ Status TakeOp::PrepareNodePostAction() {
|
|||
// Visitor accept method for NodePass
|
||||
Status TakeOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<TakeOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -237,7 +237,7 @@ Status ZipOp::EoeReceived(int32_t) {
|
|||
// Visitor accept method for NodePass
|
||||
Status ZipOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
|
||||
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
|
||||
}
|
||||
|
||||
Status ZipOp::ComputeColMap() {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "dataset/engine/perf/profiling.h"
|
||||
#include "dataset/engine/perf/monitor.h"
|
||||
|
||||
|
@ -214,7 +215,8 @@ Status ExecutionTree::PrepareTreePreAction() {
|
|||
bool modified = false;
|
||||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||
// Construct pre actions
|
||||
// example: pre_actions.push_back(new SomePass());
|
||||
MS_LOG(INFO) << "Running pre pass";
|
||||
pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
|
||||
// Apply pre action passes
|
||||
for (auto &pass : pre_actions) {
|
||||
RETURN_IF_NOT_OK(pass->Run(this, &modified));
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-opt OBJECT
|
||||
pass.cc
|
||||
util/printer_pass.cc
|
||||
pass.cc
|
||||
pre/removal_nodes.cc
|
||||
pre/removal_pass.cc
|
||||
util/printer_pass.cc
|
||||
)
|
||||
|
|
|
@ -61,6 +61,7 @@ Status NodePass::Run(ExecutionTree *tree, bool *modified) {
|
|||
|
||||
// Helper function to perform DFS visit
|
||||
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||
RETURN_IF_NOT_OK(node->PreAccept(this, modified));
|
||||
for (const auto &c : node->Children()) {
|
||||
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
|
||||
}
|
||||
|
@ -159,6 +160,5 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
|
|||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,14 +66,16 @@ class Pass : public std::enable_shared_from_this<Pass> {
|
|||
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
|
||||
class TreePass : public Pass {
|
||||
public:
|
||||
// Run the transformation pass against the execution tree.
|
||||
// @param tree - Pointer to the execution tree to be transformed.
|
||||
// @param modified - Pointer to the modified flag,
|
||||
/// \brief Run the transformation pass against the execution tree.
|
||||
/// \param[inout] tree Pointer to the execution tree to be transformed.
|
||||
/// \param[inout] modified Indicate if the tree was modified
|
||||
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||
|
||||
// Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
// @return Status - The error code return
|
||||
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
|
||||
};
|
||||
|
||||
|
@ -90,14 +92,23 @@ class NodePass : public Pass {
|
|||
|
||||
~NodePass() = default;
|
||||
|
||||
// Run the transformation pass against the execution tree.
|
||||
// @param tree - Pointer to the execution tree to be transformed.
|
||||
// @param modified - Pointer to the modified flag,
|
||||
/// \brief Run the transformation pass against the execution tree
|
||||
/// \param[inout] tree Pointer to the execution tree to be transformed
|
||||
/// \param[inout] modified Indicator if the tree was changed
|
||||
Status Run(ExecutionTree *tree, bool *modified) final;
|
||||
|
||||
// Derived classes may implement the runOnNode function to implement node level tree transformation.
|
||||
// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
// @return Status - The error code return
|
||||
/// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down
|
||||
/// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
|
||||
|
||||
/// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all.
|
||||
/// \return Status The error code return
|
||||
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
|
||||
|
||||
// Visit methods to be overridden.
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||
if (removal_pass_) {
|
||||
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
||||
} else {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RemovalPass;
|
||||
|
||||
/// \class RemovalNodes removal_nodes.h
|
||||
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||
/// It works in conjunction with the removal_pass.
|
||||
class RemovalNodes : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||
explicit RemovalNodes(RemovalPass *removal_pass);
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
bool is_caching_;
|
||||
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
RemovalPass::RemovalPass() {}
|
||||
|
||||
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
// Create the removal node pass which can identify which nodes need to be removed.
|
||||
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
||||
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||
|
||||
// Then, execute the removal of any nodes that were set up for removal
|
||||
for (auto node : removal_nodes_) {
|
||||
node->Remove();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the list of operators to be removed
|
||||
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class RemovalPass removal_pass.h
|
||||
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||
/// nodes should be removed, and then removes them.
|
||||
class RemovalPass : public TreePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RemovalPass();
|
||||
|
||||
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
/// \brief Adds an operator to the list of operators to be removed
|
||||
/// \param[in] dataset_op The operator to add to the removal list
|
||||
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
Loading…
Reference in New Issue