forked from mindspore-Ecosystem/mindspore
!9564 Tidy up code in dataset compilation phase
From: @nsyca Reviewed-by: Signed-off-by:
This commit is contained in:
commit
49fd5308a4
|
@ -556,8 +556,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
|||
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
|
||||
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
|
||||
auto vocab = std::make_shared<SentencePieceVocab>();
|
||||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
|
||||
character_coverage, model_type, params);
|
||||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
|
||||
model_type, params);
|
||||
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
|
@ -588,8 +588,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
|
|||
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
|
||||
const std::vector<std::string> &special_tokens, bool special_first) {
|
||||
auto vocab = std::make_shared<Vocab>();
|
||||
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
|
||||
special_first);
|
||||
auto ds =
|
||||
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
|
||||
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
|
|
|
@ -206,9 +206,6 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Getter function to get all of our children.
|
||||
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
|
||||
|
||||
// Getter function to get all of our parents.
|
||||
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
|
||||
|
||||
|
|
|
@ -111,9 +111,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
|
||||
void Parent(DatasetOp **parent, int32_t parent_index) const;
|
||||
|
||||
// Getter function to get all of our children.
|
||||
std::vector<std::shared_ptr<DatasetOp>> children() const;
|
||||
|
||||
// Getter function to get all of our parents.
|
||||
std::vector<DatasetOp *> parents() const;
|
||||
|
||||
|
|
|
@ -233,40 +233,14 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
|
|||
return shared_from_this();
|
||||
}
|
||||
|
||||
DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) {
|
||||
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
|
||||
// Fetch some default value from config manager
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
}
|
||||
|
||||
const bool DatasetNode::IsTree() const {
|
||||
bool is_tree = true;
|
||||
if (this->parent_.size() > 1) {
|
||||
MS_LOG(WARNING) << Name() << " has more than one parent.";
|
||||
return false;
|
||||
}
|
||||
for (const auto &child : children_) {
|
||||
is_tree = child->IsTree();
|
||||
if (!is_tree) {
|
||||
MS_LOG(WARNING) << Name() << " has more than one parent.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
return is_tree;
|
||||
}
|
||||
|
||||
// This function will perform a deep copy of the current node (and its descendants); the parent* pointer will not be copied.
|
||||
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
|
||||
std::shared_ptr<DatasetNode> new_node = this->Copy();
|
||||
// temporary fix to set the num_workers to the new node.
|
||||
new_node->SetNumWorkers(this->num_workers_);
|
||||
for (const auto &child : children_) {
|
||||
new_node->AddChild(child->DeepCopy());
|
||||
}
|
||||
return new_node;
|
||||
mappable_ = kNotADataSource;
|
||||
}
|
||||
|
||||
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
|
||||
|
@ -310,54 +284,105 @@ void DatasetNode::PrintNode(std::ostream &out, int *level) const {
|
|||
}
|
||||
|
||||
// Add a node as a child, node's parent needs to be empty
|
||||
// this function will allow child to be a nullptr, in which case it will simply skip
|
||||
// This function will allow child to be a nullptr, in which case it will simply skip.
|
||||
// This function is used only when building IR node one by one from parsing the user code.
|
||||
// During the parsing, we allow a node to have more than one parent, possibly forming a graph.
|
||||
// It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
|
||||
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
|
||||
if (child != nullptr && child->parent_.empty()) {
|
||||
if (child != nullptr) {
|
||||
children_.push_back(child);
|
||||
child->parent_.push_back(this);
|
||||
} else if (child != nullptr) {
|
||||
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
|
||||
children_.push_back(child);
|
||||
child->parent_.push_back(this);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert a node as a child of this node. This node's children becomes the children of the inserted node.
|
||||
// Add the input node to be the next child of this node
|
||||
// This function is used in doing a deep copy of the IR tree built by parsing the user code.
|
||||
// This function assumes we walk the tree in DFS left-to-right.
|
||||
// This is a temporary function to be replaced later by a set of better tree operations.
|
||||
void DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
|
||||
if (child != nullptr) {
|
||||
if (child->parent_ != nullptr) {
|
||||
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
|
||||
}
|
||||
children_.push_back(child);
|
||||
child->parent_ = this;
|
||||
}
|
||||
}
|
||||
|
||||
// Add a node as a parent, node's parent needs to be empty (future use)
|
||||
Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
|
||||
|
||||
if (node->parent_ != nullptr) {
|
||||
DatasetNode *parent = node->parent_;
|
||||
for (auto i = parent->children_.size() - 1; i >= 0; --i) {
|
||||
if (parent->children_[i] == node) {
|
||||
parent->children_[i] = static_cast<std::shared_ptr<DatasetNode>>(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
children_.push_back(node);
|
||||
node->parent_ = this;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Insert a node as a child of this node
|
||||
// This node's children become the children of the inserted node.
|
||||
Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent.");
|
||||
|
||||
for (auto child : children_) {
|
||||
node->children_.push_back(child);
|
||||
child->parent_.clear();
|
||||
child->parent_.push_back(node.get());
|
||||
child->parent_ = node.get();
|
||||
}
|
||||
// Then establish the new parent-child relationship with the new parent.
|
||||
children_.clear();
|
||||
children_.push_back(node);
|
||||
node->parent_.clear();
|
||||
node->parent_.push_back(this);
|
||||
node->parent_ = this;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Insert a node as a child next to this node (future use)
|
||||
Status DatasetNode::InsertAfter(std::shared_ptr<DatasetNode> node) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must have a parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent.");
|
||||
auto size = parent_->children_.size();
|
||||
// Duplicate the last child to increase the size by 1
|
||||
parent_->children_.push_back(parent_->children_[size - 1]);
|
||||
// Shift each child to its right until we found the insertion point, then insert the input node
|
||||
bool found = false;
|
||||
for (auto i = parent_->children_.size() - 2; i >= 0; --i) {
|
||||
if (parent_->children_[i].get() != this) {
|
||||
parent_->children_[i + 1] = parent_->children_[i];
|
||||
} else {
|
||||
parent_->children_[i + 1] = node;
|
||||
node->parent_ = parent_;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!found, "Insertion point not found.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Remove this node from its parent. Add the child of this node to its parent.
|
||||
// for now, this remove is limited to node with a single child or no child
|
||||
Status DatasetNode::Remove() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(parent_.size() != 0, "Cannot remove root or a node without parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
|
||||
if (children_.empty()) { // I am a leaf node, remove me from my parent's children list
|
||||
parent_[0]->children_.erase(
|
||||
std::remove(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()),
|
||||
parent_[0]->children_.end()); // removal using "erase remove idiom"
|
||||
} else { // replace my position in my parent's children list with my single child
|
||||
auto itr = std::find(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_[0]->children_.end(), "I am not in my parent's children list.");
|
||||
parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()),
|
||||
parent_->children_.end()); // removal using "erase remove idiom"
|
||||
} else { // replace my position in my parent's children list with my single child
|
||||
auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
|
||||
children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent
|
||||
*itr = std::move(children_[0]); // replace me in my parent's children list with my single child
|
||||
children_.clear(); // release my single child from my children list
|
||||
children_.clear(); // release my single child from my children list
|
||||
}
|
||||
parent_[0] = nullptr;
|
||||
parent_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -146,10 +146,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
return out;
|
||||
}
|
||||
|
||||
/// \brief Make a new copy of the tree from the current node
|
||||
/// \return The new copy of the tree
|
||||
std::shared_ptr<DatasetNode> DeepCopy();
|
||||
|
||||
/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
|
@ -175,43 +171,61 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \return Child nodes
|
||||
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
|
||||
|
||||
/// \brief Getter function for parents nodes
|
||||
/// \return Parent nodes
|
||||
const std::vector<DatasetNode *> Parent() const { return parent_; }
|
||||
/// \brief Getter function for the parent node
|
||||
/// \return The parent node (of a node from a cloned IR tree)
|
||||
DatasetNode *Parent() const { return parent_; }
|
||||
|
||||
/// \brief Establish the parent-child relationship between this node and its child.
|
||||
/// \brief Establish a parent-child relationship between this node and the input node.
|
||||
/// Used when building the IR tree.
|
||||
void AddChild(std::shared_ptr<DatasetNode> child);
|
||||
|
||||
/// \brief Establish a parent-child relationship between this node and the input node.
|
||||
/// Used during the cloning of the user-input IR tree (temporary use)
|
||||
void AppendChild(std::shared_ptr<DatasetNode> child);
|
||||
|
||||
/// \brief Establish the child-parent relationship between this node and the input node (future use)
|
||||
Status InsertAbove(std::shared_ptr<DatasetNode> node);
|
||||
|
||||
/// \brief Insert the input node below this node. This node's children becomes the children of the inserted node.
|
||||
Status InsertBelow(std::shared_ptr<DatasetNode> node);
|
||||
|
||||
/// \brief Add the input node as the next sibling (future use)
|
||||
Status InsertAfter(std::shared_ptr<DatasetNode> node);
|
||||
|
||||
/// \brief detach this node from its parent, add its child (if any) to its parent
|
||||
/// \return error code, return error if node has more than 1 children
|
||||
Status Remove();
|
||||
|
||||
/// \brief Check if this node has cache
|
||||
/// \brief Check if this node has cache
|
||||
/// \return True if the data of this node will be cached
|
||||
const bool IsCached() const { return (cache_ != nullptr); }
|
||||
|
||||
/// \brief Check if this node is a tree
|
||||
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
|
||||
const bool IsTree() const;
|
||||
|
||||
/// \brief Check if this node is a leaf node.
|
||||
/// \brief Check if this node is a leaf node.
|
||||
/// \return True if this is a leaf node.
|
||||
const bool IsLeaf() const { return children_.empty(); }
|
||||
|
||||
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
||||
/// \return True if the dataset represented by this node is a mappable dataset
|
||||
const bool IsMappable() const { return mappable_; }
|
||||
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
||||
/// \return True if this node is a mappable dataset
|
||||
const bool IsMappable() const { return (mappable_ == kMappableSource); }
|
||||
|
||||
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||
/// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes
|
||||
/// \return True if this node is a non-mappable dataset
|
||||
const bool IsNonMappable() const { return (mappable_ == kNonMappableSource); }
|
||||
|
||||
/// \brief Check if this node is not a data source node.
|
||||
/// \return True if this node is not a data source node
|
||||
const bool IsNotADataSource() const { return (mappable_ == kNotADataSource); }
|
||||
|
||||
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||
/// \return True if a cache-enabled operator is an ancestor of this node
|
||||
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
|
||||
|
||||
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||
void HasCacheAbove() { descendant_of_cache_ = true; }
|
||||
|
||||
/// \brief Getter of the number of workers
|
||||
int32_t num_workers() { return num_workers_; }
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
@ -247,7 +261,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<DatasetNode>> children_;
|
||||
std::vector<DatasetNode *> parent_;
|
||||
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
|
||||
std::shared_ptr<DatasetCache> cache_;
|
||||
int64_t dataset_size_ = -1;
|
||||
int32_t num_workers_;
|
||||
|
@ -257,7 +271,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
||||
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
||||
void PrintNode(std::ostream &out, int *level) const;
|
||||
bool mappable_;
|
||||
enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 };
|
||||
enum DataSource mappable_;
|
||||
bool descendant_of_cache_;
|
||||
};
|
||||
|
||||
|
@ -265,12 +280,12 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
class MappableSourceNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MappableSourceNode() : DatasetNode() { mappable_ = true; }
|
||||
MappableSourceNode() : DatasetNode() { mappable_ = kMappableSource; }
|
||||
|
||||
/// \brief Constructor that initializes the cache
|
||||
/// \param dataset_cache DatasetCache
|
||||
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
||||
mappable_ = true;
|
||||
mappable_ = kMappableSource;
|
||||
// Initially set to false, and set to true by the optimizer when conditions are met.
|
||||
descendant_of_cache_ = false;
|
||||
}
|
||||
|
@ -287,12 +302,12 @@ class MappableSourceNode : public DatasetNode {
|
|||
class NonMappableSourceNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
|
||||
NonMappableSourceNode() : DatasetNode() { mappable_ = kNonMappableSource; }
|
||||
|
||||
/// \brief Constructor that initializes the cache
|
||||
/// \param dataset_cache DatasetCache
|
||||
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
||||
mappable_ = false;
|
||||
mappable_ = kNonMappableSource;
|
||||
// Initially set to false, and set to true by the optimizer when conditions are met.
|
||||
descendant_of_cache_ = false;
|
||||
}
|
||||
|
|
|
@ -27,13 +27,14 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for RootNode
|
||||
RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) {
|
||||
// The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode)
|
||||
RootNode::RootNode(std::shared_ptr<DatasetNode> child) : DatasetNode() {
|
||||
// The root node's parent must remain nullptr, which is set in the constructor of DatasetNode.
|
||||
AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> RootNode::Copy() {
|
||||
auto node = std::make_shared<RootNode>(nullptr, num_epochs_);
|
||||
auto node = std::make_shared<RootNode>(nullptr);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
@ -54,7 +55,7 @@ Status RootNode::ValidateParams() {
|
|||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (parent_.size() != 0) {
|
||||
if (parent_ != nullptr) {
|
||||
std::string err_msg = "Internal error: root node should not have a parent";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
|
|
|
@ -29,7 +29,10 @@ namespace dataset {
|
|||
class RootNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
||||
RootNode() : DatasetNode() {}
|
||||
|
||||
/// \brief Constructor
|
||||
explicit RootNode(std::shared_ptr<DatasetNode> child);
|
||||
|
||||
/// \brief Destructor
|
||||
~RootNode() = default;
|
||||
|
@ -54,6 +57,9 @@ class RootNode : public DatasetNode {
|
|||
/// \brief Getter of number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
|
||||
/// \brief Setter of number of epochs
|
||||
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
|
|
@ -7,6 +7,7 @@ add_library(engine-opt OBJECT
|
|||
pre/cache_error_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/cache_validation_pass.cc
|
||||
pre/deep_copy_pass.cc
|
||||
pre/epoch_ctrl_pass.cc
|
||||
pre/epoch_injection_pass.cc
|
||||
pre/getter_pass.cc
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Start the clone with an empty RootNode. It is the initial parent under which the first cloned node is appended.
// The initializer order matches the declaration order in the class (root_ first, then parent_).
DeepCopyPass::DeepCopyPass() : root_(std::make_shared<RootNode>()), parent_(root_.get()) {}
|
||||
|
||||
Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
  *modified = true;

  // Compare every pair of children to detect a node that lists the same child more than once.
  // This is an artificial restriction; it could be lifted since this pass clones the input tree anyway.
  // Example: ds2 = ds1 + ds1;
  const auto kids = node->Children();
  const size_t num_kids = kids.size();
  for (size_t i = 0; i + 1 < num_kids; ++i) {
    for (size_t j = i + 1; j < num_kids; ++j) {
      if (kids[i] == kids[j]) {
        std::string err_msg = "The same node " + kids[i]->Name() + " is a child of its parent more than once.";
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
    }
  }

  // Make a copy of the visited node.
  std::shared_ptr<DatasetNode> clone = node->Copy();
  // Temporary fix: carry num_workers over to the cloned node explicitly.
  // A cleaner approach would be a base-class method in DatasetNode that transfers such properties,
  // invoked from each derived class's Copy().
  clone->SetNumWorkers(node->num_workers());
  // Attaching to the current parent relies on a DFS walk that visits children from first to last.
  // Future: a more robust implementation that does not depend on that assumption.
  parent_->AppendChild(clone);

  // The clone becomes the parent that will receive the copies of its own children.
  parent_ = clone.get();
  return Status::OK();
}
|
||||
|
||||
Status DeepCopyPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
  *modified = true;
  // The subtree below the current parent has been fully visited; step back up one level in the cloned tree.
  DatasetNode *const finished = parent_;
  parent_ = finished->Parent();
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class DeepCopyPass
|
||||
/// \brief This pass clones a new copy of IR tree. A new copy is used in the compilation to avoid any modification to
|
||||
/// the IR tree associated with the user code.
|
||||
class DeepCopyPass : public IRNodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
DeepCopyPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~DeepCopyPass() = default;
|
||||
|
||||
/// \brief Clone a new copy of the node
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] *modified indicates whether the node has been visited
|
||||
/// \return Status code
|
||||
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Reset parent after walking its sub tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] *modified indicates whether the node has been visited
|
||||
/// \return Status code
|
||||
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter method to retrieve the root node
|
||||
/// \return the root node of the new cloned tree
|
||||
std::shared_ptr<RootNode> Root() { return root_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<RootNode> root_;
|
||||
DatasetNode *parent_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||
|
@ -24,6 +25,19 @@ namespace dataset {
|
|||
Status InputValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
  // This pass only validates; it never changes the tree.
  *modified = false;
  RETURN_IF_NOT_OK(node->ValidateParams());

  const bool is_leaf = node->IsLeaf();
  const bool is_data_source = node->IsMappable() || node->IsNonMappable();

  // A data source (mappable or non-mappable) must not have children.
  if (is_data_source && !is_leaf) {
    std::string err_msg = node->Name() + " is a data source and must be a leaf node.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // A node that is not a data source (i.e. a dataset operator) must have at least one child.
  if (is_leaf && node->IsNotADataSource()) {
    std::string err_msg = node->Name() + " is a dataset operator and must not be a leaf node.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
|
||||
|
@ -27,6 +28,11 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
TreeAdapter::TreeAdapter() {
  // Start in the initial compile state.
  tree_state_ = kCompileStateInit;
  // The optimizer is switched on only when the OPTIMIZE environment variable is exactly "true".
  const bool env_requests_optimize = (common::GetEnv("OPTIMIZE") == "true");
  optimize_ = env_requests_optimize;
}
|
||||
|
||||
Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
||||
// Vector of actions in pre-pass phase
|
||||
std::vector<std::unique_ptr<IRPass>> actions;
|
||||
|
@ -86,7 +92,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
||||
Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
||||
// Build the DatasetOp ExecutionTree from the optimized IR tree
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(ir->Build(&ops));
|
||||
|
@ -104,47 +110,20 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha
|
|||
// Build the children of IR, once they return, add the return value to *op
|
||||
for (std::shared_ptr<DatasetNode> child_ir : ir->Children()) {
|
||||
std::shared_ptr<DatasetOp> child_op;
|
||||
RETURN_IF_NOT_OK(BuildExecutionTree(child_ir, &child_op));
|
||||
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(child_ir, &child_op));
|
||||
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
|
||||
optimize_ = true; // Always ON (temporary)
|
||||
|
||||
RETURN_UNEXPECTED_IF_NULL(input_ir);
|
||||
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
|
||||
|
||||
// We will first walk the input tree to sanity check this is not a graph
|
||||
// Flag an error when it is not a tree
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)");
|
||||
|
||||
// Copy the input IR tree and insert under the root node
|
||||
// Create a root node to host the new copy of the input IR tree to pass to the optimizer
|
||||
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
|
||||
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
|
||||
|
||||
// Pre-pass of the IR tree
|
||||
RETURN_IF_NOT_OK(PrePass(root_ir));
|
||||
|
||||
// Optional phase of optimization
|
||||
if (optimize_) {
|
||||
RETURN_IF_NOT_OK(Optimize(root_ir));
|
||||
}
|
||||
|
||||
// Post-pass of the IR tree
|
||||
RETURN_IF_NOT_OK(PostPass(root_ir));
|
||||
|
||||
MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';
|
||||
|
||||
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
|
||||
// This will evolve in the long run
|
||||
tree_ = std::make_unique<ExecutionTree>();
|
||||
|
||||
// Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree
|
||||
std::shared_ptr<DatasetOp> root_op;
|
||||
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op));
|
||||
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
|
||||
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
||||
|
@ -165,6 +144,48 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
  RETURN_UNEXPECTED_IF_NULL(input_ir);

  tree_state_ = kCompileStateIRGraphBuilt;
  MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';

  // Step 1: clone the input IR under a new root node.
  // Compilation then works on (and may modify) the clone, leaving the tree associated with the user code untouched.
  // The user's IR may be a graph in which a node is consumed by more than one parent; cloning turns it into a
  // tree because every consumption of a node gets its own copy.
  DeepCopyPass cloner;
  bool tree_modified = false;
  RETURN_IF_NOT_OK(cloner.Run(input_ir, &tree_modified));
  std::shared_ptr<RootNode> root_ir = cloner.Root();
  root_ir->SetNumEpochs(num_epochs);

  tree_state_ = kCompileStateIRTreeCloned;
  MS_LOG(INFO) << "Plan before optimization:" << '\n' << *root_ir << '\n';

  // Step 2: pre-pass, optional optimization, then post-pass over the cloned IR tree.
  RETURN_IF_NOT_OK(PrePass(root_ir));
  if (optimize_) {
    RETURN_IF_NOT_OK(Optimize(root_ir));
  }
  RETURN_IF_NOT_OK(PostPass(root_ir));

  tree_state_ = kCompileStateOptimized;
  MS_LOG(INFO) << "Plan after optimization:" << '\n' << *root_ir << '\n';

  // Step 3: build the runtime tree from the processed IR tree.
  RETURN_IF_NOT_OK(Build(root_ir, num_epochs));
  tree_state_ = kCompileStateReady;

  return Status::OK();
}
|
||||
|
||||
Status TreeAdapter::GetNext(TensorRow *row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
|
|
|
@ -33,7 +33,7 @@ class DatasetNode;
|
|||
|
||||
class TreeAdapter {
|
||||
public:
|
||||
TreeAdapter() = default;
|
||||
TreeAdapter();
|
||||
|
||||
~TreeAdapter() = default;
|
||||
|
||||
|
@ -68,28 +68,40 @@ class TreeAdapter {
|
|||
bool OptimizationEnabled() const { return optimize_; }
|
||||
|
||||
private:
|
||||
// This function runs a mandatory pass checking the syntax and semantics of the IR tree.
|
||||
// Run the mandatory pass checking the syntax and semantics of the IR tree
|
||||
Status PrePass(std::shared_ptr<DatasetNode> ir);
|
||||
|
||||
// This function runs an optional optimization pass on the IR tree.
|
||||
// Run the optional optimization pass on the IR tree
|
||||
Status Optimize(std::shared_ptr<DatasetNode> ir);
|
||||
|
||||
// This function runs a mandatory pass augmenting the IR tree before the execution.
|
||||
// Run the mandatory pass augmenting the IR tree
|
||||
Status PostPass(std::shared_ptr<DatasetNode> ir);
|
||||
|
||||
// Build an Execution tree
|
||||
Status Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs);
|
||||
|
||||
// This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree.
|
||||
Status BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
|
||||
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
|
||||
|
||||
std::unique_ptr<DataBuffer> cur_db_;
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling
|
||||
int32_t num_epochs_;
|
||||
std::unique_ptr<ExecutionTree> tree_; // execution tree created in Build() from the IR tree
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
|
||||
int32_t cur_batch_num_; // current batch number, used for profiling
|
||||
int32_t cur_connector_size_; // current connector size of root op, used for profiling
|
||||
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
|
||||
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
|
||||
|
||||
// State flags for the lifecycle of the tree
|
||||
enum CompileState {
|
||||
kCompileStateInit = 0, // The freshly initialized state
|
||||
kCompileStateIRGraphBuilt, // User code has been parsed and its IR graph built
|
||||
kCompileStateIRTreeCloned, // IR tree has been cloned from the IR graph
|
||||
kCompileStateOptimized, // IR tree has been optimized
|
||||
kCompileStateReady // Execution tree is generated from the optimized IR
|
||||
};
|
||||
CompileState tree_state_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -65,7 +65,6 @@ SET(DE_UT_SRCS
|
|||
image_folder_op_test.cc
|
||||
image_process_test.cc
|
||||
interrupt_test.cc
|
||||
ir_node_test.cc
|
||||
jieba_tokenizer_op_test.cc
|
||||
main_test.cc
|
||||
map_op_test.cc
|
||||
|
|
|
@ -1,137 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
class MindDataTestIRNodes : public UT::DatasetOpTesting {
|
||||
public:
|
||||
MindDataTestIRNodes() = default;
|
||||
// Compare the node pointers of two trees; used to test the deep copy of nodes. Returns an error status
|
||||
// if (ptr1 == ptr2) does not equal flag, or if the two trees have different structures (or node names differ).
|
||||
Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
|
||||
if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
|
||||
std::string err_msg =
|
||||
"Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
size_t num_child = root1->Children().size();
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(),
|
||||
root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " +
|
||||
std::to_string(root2->Children().size()) + " child.");
|
||||
|
||||
for (size_t ind = 0; ind < num_child; ind++) {
|
||||
RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// print the node's name in post order
|
||||
Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ir);
|
||||
for (auto child : ir->Children()) {
|
||||
RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
|
||||
}
|
||||
names += (ir->Name() + "->");
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";
|
||||
|
||||
auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
|
||||
|
||||
auto tree2 = tree1->DeepCopy();
|
||||
std::string tree_1_names, tree_2_names;
|
||||
|
||||
ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
|
||||
ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
|
||||
|
||||
// expected output for the 2 names:
|
||||
// RandomDataset->Repeat->Project->Shuffle->Batch->
|
||||
EXPECT_EQ(tree_1_names, tree_2_names);
|
||||
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
|
||||
|
||||
// verify compare function is correct
|
||||
EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";
|
||||
|
||||
auto branch1 = RandomData(44)->Project({"label"});
|
||||
auto branch2 = RandomData(44)->Shuffle(10);
|
||||
|
||||
auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode();
|
||||
|
||||
auto tree2 = tree1->DeepCopy();
|
||||
std::string tree_1_names, tree_2_names;
|
||||
|
||||
ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
|
||||
ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
|
||||
|
||||
// expected output for the 2 names:
|
||||
// RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
|
||||
EXPECT_EQ(tree_1_names, tree_2_names);
|
||||
|
||||
// verify that the pointers within the same tree are the same
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
|
||||
// verify two trees
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";
|
||||
|
||||
auto branch1 = RandomData(44)->Project({"label"});
|
||||
auto branch2 = ImageFolder("path");
|
||||
auto tree = Zip({branch1, branch2})->IRNode();
|
||||
/***
|
||||
tree looks like this, we will remove node and test its functionalities
|
||||
Zip
|
||||
/ \
|
||||
Project ImageFolder
|
||||
/
|
||||
RandomData
|
||||
***/
|
||||
auto tree_copy_1 = tree->DeepCopy();
|
||||
ASSERT_EQ(tree_copy_1->Children().size(), 2);
|
||||
// remove the project in the tree and test
|
||||
ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree
|
||||
ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
|
||||
// remove the ImageFolder, a leaf node from the tree
|
||||
std::string tree_1_names, tree_2_names;
|
||||
ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
|
||||
EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");
|
||||
auto tree_copy_2 = tree->DeepCopy();
|
||||
ASSERT_EQ(tree_copy_2->Children().size(), 2);
|
||||
tree_copy_2->Children()[1]->Remove();
|
||||
ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
|
||||
EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
|
||||
}
|
Loading…
Reference in New Issue