[MD] skip node pushdown optimization pass for Reset

This commit is contained in:
mohammad 2022-03-08 14:17:18 -05:00
parent 2655d64720
commit 260cebf650
23 changed files with 1351 additions and 21 deletions

View File

@ -7,6 +7,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
random_sampler.cc
sampler.cc
sequential_sampler.cc
skip_first_epoch_sampler.cc
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,7 +66,7 @@ class SequentialSamplerRT : public SamplerRT {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
private:
protected:
int64_t current_id_; // The id sequencer. Each new id increments from this
int64_t start_index_; // The starting id. current_id_ begins from here.
int64_t id_count_; // An internal counter that tracks how many ids have been produced

View File

@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h"
#include <string>
namespace mindspore {
namespace dataset {
// Prepare the sampler for the next epoch.
// Fails when the number of ids produced so far (id_count_) differs from num_samples_.
// On the first successful reset, start_index_ is folded back into num_samples_ and cleared, so that
// subsequent epochs begin at id 0. Any child sampler is reset afterwards.
// @return Status The status code returned
Status SkipFirstEpochSamplerRT::ResetSampler() {
  const bool epoch_complete = (id_count_ == num_samples_);
  if (!epoch_complete) {
    std::string message = "[Internal ERROR] ResetSampler() called early or late. id_count_: ";
    message += std::to_string(id_count_);
    message += " num_samples_: ";
    message += std::to_string(num_samples_);
    MS_LOG(ERROR) << message;
    RETURN_STATUS_UNEXPECTED(message);
  }

  // Restart the id sequence.
  id_count_ = 0;
  current_id_ = 0;

  // One-time adjustment after the first epoch: widen the sample count by the initial offset and drop the offset.
  if (!first_epoch_done_) {
    first_epoch_done_ = true;
    num_samples_ += start_index_;
    samples_per_tensor_ = num_samples_;
    start_index_ = 0;
  }

  if (!HasChildSampler()) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(child_[0]->ResetSampler());
  return Status::OK();
}
// Always reports -1, regardless of num_rows.
int64_t SkipFirstEpochSamplerRT::CalculateNumSamples(const int64_t num_rows) {
  constexpr int64_t kNumSamplesUnknown = -1;
  return kNumSamplesUnknown;
}
// Printer for debugging purposes.
// @param out - output stream to write to
// @param show_all - when false only the sampler name is written; when true the base-class details and this
//                   sampler's counters are appended
void SkipFirstEpochSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
  out << "\nSampler: SkipFirstEpochSampler";
  if (!show_all) {
    return;
  }
  // Common detailed info from the base sampler first, then our own fields.
  SamplerRT::SamplerPrint(out, show_all);
  out << "\nStart index: " << start_index_ << "\nFirst epoch done: " << first_epoch_done_
      << "\nCurrent id: " << current_id_ << "\nid count:" << id_count_;
}
/// \brief Serialize this sampler's arguments
/// \param[out] out_json Receives the base sampler arguments plus "sampler_name" and "start_index"
/// \return Status of the function (error if out_json is null or the base serialization fails)
Status SkipFirstEpochSamplerRT::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);

  // Start from whatever the base class serializes, then add the fields specific to this sampler.
  nlohmann::json sampler_args;
  RETURN_IF_NOT_OK(SamplerRT::to_json(&sampler_args));
  sampler_args["sampler_name"] = "SkipFirstEpochSampler";
  sampler_args["start_index"] = start_index_;

  *out_json = sampler_args;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
namespace mindspore {
namespace dataset {
/// \brief Runtime sampler derived from SequentialSamplerRT that tracks whether its first epoch has finished
///     (see first_epoch_done_) and overrides reset, sample counting, printing and serialization accordingly.
class SkipFirstEpochSamplerRT : public SequentialSamplerRT {
 public:
  /// \brief Constructors are inherited unchanged from SequentialSamplerRT
  using SequentialSamplerRT::SequentialSamplerRT;

  /// \brief Destructor
  ~SkipFirstEpochSamplerRT() = default;

  /// \brief Reset the sampler so it can produce the sample ids of the next epoch
  /// \return Status The status code returned
  Status ResetSampler() override;

  /// \brief Gets the number of samples available
  /// \note Since this sampler returns different number of samples in the first epoch (compared to other epochs), this
  ///     function always returns -1
  /// \param[in] num_rows The total number of rows in the dataset
  /// \return int64_t Calculated number of samples (always -1)
  int64_t CalculateNumSamples(int64_t num_rows) override;

  /// \brief Printer for debugging purposes
  /// \param[in] out Output stream to write to
  /// \param[in] show_all Whether to show detailed output instead of a summary
  void SamplerPrint(std::ostream &out, bool show_all) const override;

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

 private:
  bool first_epoch_done_ = false;  // false until ResetSampler() has completed once
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,8 +21,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -73,5 +74,18 @@ Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNo
*result = std::make_shared<ProjectNode>(ds, columns);
return Status::OK();
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's Visit overload.
// Returns an error Status when the pass pointer is null.
Status ProjectNode::Accept(IRNodePass *const p, bool *const modified) {
  // p is dereferenced below; report a null pass as an error instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(p);
  // Downcast shared pointer then call visitor
  return p->Visit(shared_from_base<ProjectNode>(), modified);
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's VisitAfter overload.
// Returns an error Status when the pass pointer is null.
Status ProjectNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
  // p is dereferenced below; report a null pass as an error instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(p);
  // Downcast shared pointer then call visitor
  return p->VisitAfter(shared_from_base<ProjectNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -70,6 +70,18 @@ class ProjectNode : public DatasetNode {
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
/// \brief Base-class override for accepting IRNodePass visitor; hands this node to the pass's Visit method
/// \param[in] p The pass (visitor) to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor; hands this node to the pass's VisitAfter method
/// \param[in] p The pass (visitor) to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
std::vector<std::string> columns_;
};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,8 +21,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -82,5 +83,17 @@ Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNod
return Status::OK();
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's Visit overload.
// Returns an error Status when the pass pointer is null.
Status RenameNode::Accept(IRNodePass *const p, bool *const modified) {
  // p is dereferenced below; report a null pass as an error instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(p);
  // Downcast shared pointer then call visitor
  return p->Visit(shared_from_base<RenameNode>(), modified);
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's VisitAfter overload.
// Returns an error Status when the pass pointer is null.
Status RenameNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
  // p is dereferenced below; report a null pass as an error instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(p);
  // Downcast shared pointer then call visitor
  return p->VisitAfter(shared_from_base<RenameNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,18 @@ class RenameNode : public DatasetNode {
/// \return Name of the current node
std::string Name() const override { return kRenameNode; }
/// \brief Base-class override for accepting IRNodePass visitor; hands this node to the pass's Visit method
/// \param[in] p The pass (visitor) to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor; hands this node to the pass's VisitAfter method
/// \param[in] p The pass (visitor) to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -99,6 +99,9 @@ class SkipNode : public DatasetNode {
void SetFirstEpochOnly(bool flag) {
  // Overwrites the stored first-epoch-only flag with the caller's value.
  first_epoch_only_ = flag;
}
/// \brief Getter functions
const bool FirstEpochOnly() const { return first_epoch_only_; }
private:
int32_t skip_count_;
bool first_epoch_only_ = false;

View File

@ -8,6 +8,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES
random_sampler_ir.cc
samplers_ir.cc
sequential_sampler_ir.cc
skip_first_epoch_sampler_ir.cc
subset_random_sampler_ir.cc
subset_sampler_ir.cc
weighted_random_sampler_ir.cc

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,7 +65,7 @@ class SequentialSamplerObj : public SamplerObj {
Status ValidateParams() override;
private:
protected:
int64_t start_index_;
int64_t num_samples_;
};

View File

@ -0,0 +1,68 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h"
#include "minddata/dataset/util/validators.h"
namespace mindspore {
namespace dataset {
// Constructor: forwards start_index to SequentialSamplerObj and always passes 0 as its second argument.
SkipFirstEpochSamplerObj::SkipFirstEpochSamplerObj(int64_t start_index)
    : SequentialSamplerObj(start_index, 0) {
}
// Destructor: compiler-generated; defined out of line here rather than in the header.
SkipFirstEpochSamplerObj::~SkipFirstEpochSamplerObj() = default;
/// \brief Get the arguments of node
/// \param[out] out_json Receives the base sampler arguments plus "sampler_name" and "start_index"
/// \return Status of the function (error if out_json is null or the base serialization fails)
Status SkipFirstEpochSamplerObj::to_json(nlohmann::json *const out_json) {
  // out_json is written through below; reject null up front, as the runtime sampler's to_json does.
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
  args["sampler_name"] = "SkipFirstEpochSamplerObj";
  args["start_index"] = start_index_;
  *out_json = args;
  return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Build a SkipFirstEpochSamplerObj from its JSON description
/// \param[in] json_obj JSON object to read; must contain "start_index"
/// \param[out] sampler Receives the constructed sampler
/// \return Status of the function (error if sampler is null, "start_index" is missing, or the base class fails)
Status SkipFirstEpochSamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) {
  // sampler is written through below; reject null instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(sampler);
  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "start_index", "SkipFirstEpochSamplerObj"));
  int64_t start_index = json_obj["start_index"];
  *sampler = std::make_shared<SkipFirstEpochSamplerObj>(start_index);
  // Run common code in super class to add children samplers
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
/// \brief Create the runtime sampler for this IR sampler and attach its children
/// \param[out] sampler Receives the new SkipFirstEpochSamplerRT; set to nullptr when building the children fails
/// \return Status of the function (error if sampler is null or BuildChildren fails)
Status SkipFirstEpochSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
  // sampler is written through below; reject null instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(sampler);
  // runtime sampler object
  *sampler = std::make_shared<dataset::SkipFirstEpochSamplerRT>(start_index_, 0);
  Status s = BuildChildren(sampler);
  if (s.IsError()) {
    // Clear the output itself. The previous code only reassigned the local parameter, which had no effect
    // and left a partially built sampler visible to the caller.
    *sampler = nullptr;
  }
  return s;
}
std::shared_ptr<SamplerObj> SkipFirstEpochSamplerObj::SamplerCopy() {
auto sampler = std::make_shared<SkipFirstEpochSamplerObj>(start_index_);
for (const auto &child : children_) {
Status rc = sampler->AddChildSampler(child);
if (rc.IsError()) {
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
}
}
return sampler;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_
#include <memory>
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
#include "include/api/status.h"
namespace mindspore {
namespace dataset {
// Internal Sampler class forward declaration
class SamplerRT;
/// \brief IR-level sampler object derived from SequentialSamplerObj and constructed from a start index only.
class SkipFirstEpochSamplerObj : public SequentialSamplerObj {
 public:
  /// \brief Constructor
  /// \param[in] start_index Index forwarded to the SequentialSamplerObj base
  explicit SkipFirstEpochSamplerObj(int64_t start_index);

  /// \brief Destructor
  ~SkipFirstEpochSamplerObj() override;

  /// \brief Build the runtime sampler corresponding to this object
  /// \param[out] sampler Receives the runtime sampler
  /// \return Status of the function
  Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;

  /// \brief Create a copy of this sampler object
  /// \return The copied sampler
  std::shared_ptr<SamplerObj> SamplerCopy() override;

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *const out_json) override;

#ifndef ENABLE_ANDROID
  /// \brief Function for read sampler from JSON object
  /// \param[in] json_obj JSON object to be read
  /// \param[out] sampler Sampler constructed from parameters in JSON object
  /// \return Status of the function
  static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
#endif
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_

View File

@ -15,6 +15,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES
pre/input_validation_pass.cc
pre/node_offload_pass.cc
pre/node_removal_pass.cc
pre/skip_pushdown_pass.cc
)
if(ENABLE_PYTHON)

View File

@ -0,0 +1,177 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
namespace mindspore {
namespace dataset {
// Constructor: starts with no pending skip count.
SkipPushdownPass::SkipNodes::SkipNodes()
    : skip_count_(0) {
}
// A first-epoch-only skip node adds its count to skip_count_ and is queued for removal.
// Any other skip node is handed to the generic DatasetNode handler.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
  if (node->FirstEpochOnly()) {
    skip_count_ += node->Count();
    nodes_to_remove_.push_back(node);
    return Status::OK();
  }
  return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
// For a first-epoch-only skip node, verify that skip_count_ is back to zero when the node is left.
// Any other skip node is handed to the generic DatasetNode VisitAfter handler.
Status SkipPushdownPass::SkipNodes::VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) {
  if (node->FirstEpochOnly()) {
    CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ == 0, "The skip_count_ cannot be non-zero here.");
    return Status::OK();
  }
  return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
// Scale a pending skip count by the batch size of the node being visited.
// With ENABLE_PYTHON, a batch node that has a batch-size function goes to the generic DatasetNode handler instead.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
#ifdef ENABLE_PYTHON
  if (node->BatchSizeFunc()) {
    return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  }
#endif
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  // Only an active skip node above (non-zero count) needs rescaling; otherwise this is the normal flow.
  if (skip_count_ != 0) {
    skip_count_ *= node->BatchSize();
  }
  return Status::OK();
}
// A project node leaves skip_count_ untouched; only the non-negativity of the pending count is verified.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  // Same result whether or not a skip is pending.
  return Status::OK();
}
// A rename node leaves skip_count_ untouched; only the non-negativity of the pending count is verified.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  // Same result whether or not a skip is pending.
  return Status::OK();
}
// Consume a pending skip count at a mappable source by wrapping the node's sampler in a
// SkipFirstEpochSamplerObj. The node's existing sampler (if any) becomes the child of the new sampler.
// Returns an error Status if the count is negative or the existing sampler cannot be attached as a child.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ == 0) return Status::OK();  // no active skip node above. normal flow
  // we have an active skip node above.
  auto new_sampler = std::make_shared<SkipFirstEpochSamplerObj>(skip_count_);
  MS_LOG(INFO) << "Adding SkipFirstEpochSampler(" << skip_count_ << ")";
  auto sampler = node->Sampler();
  if (sampler != nullptr) {
    // Propagate a failure instead of silently installing a sampler that lost the original as its child.
    RETURN_IF_NOT_OK(new_sampler->AddChildSampler(sampler));
  }
  node->SetSampler(new_sampler);
  skip_count_ = 0;
  return Status::OK();
}
// A map node does not change skip_count_; when a skip is pending, a warning about random transformations is logged.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  const bool skip_pending = (skip_count_ != 0);
  if (skip_pending) {
    MS_LOG(WARNING)
      << "Pushing down skip node below a map node will result in slightly different outputs for random transformations.";
  }
  return Status::OK();
}
// Consume a pending skip count at a non-mappable source by recording that a skip node must be inserted above it.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // Remember the node together with the count, then mark the pending skip as handled.
    insert_skip_above_.emplace_back(node, skip_count_);
    skip_count_ = 0;
  }
  return Status::OK();
}
#ifndef ENABLE_ANDROID
// MindDataset would need its own SkipFirstEpochSampler, which is not implemented, so a pending skip count is
// consumed here by recording that a skip node must be inserted above this node.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // Remember the node together with the count, then mark the pending skip as handled.
    insert_skip_above_.emplace_back(node, skip_count_);
    skip_count_ = 0;
  }
  return Status::OK();
}
#endif
// This function is used for ops that are random and for those whose dedicated Visit is not implemented yet:
// a pending skip count is consumed by recording that a skip node must be inserted above this node.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // Remember the node together with the count, then mark the pending skip as handled.
    insert_skip_above_.emplace_back(node, skip_count_);
    skip_count_ = 0;
  }
  return Status::OK();
}
// Constructor: nothing to initialize.
SkipPushdownPass::SkipPushdownPass() = default;
// Walk the tree to push down the skip node inserted when Reset is called.
// The SkipNodes node pass runs only when root_ir->IsSizeDefined() is true. Afterwards the skip nodes it
// requested are inserted and the nodes it queued for removal are dropped.
// \param[in, out] root_ir Root of the IR tree to operate on; must not be null
// \param[in, out] modified Set to true when any node is inserted or removed; must not be null
// \return Status The status code returned
Status SkipPushdownPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
  // Both pointers are dereferenced below; report null as an error instead of crashing.
  RETURN_UNEXPECTED_IF_NULL(root_ir);
  RETURN_UNEXPECTED_IF_NULL(modified);
  MS_LOG(INFO) << "Pre pass: skip node pushdown pass started.";
  // Assumption: The total skip counts in the first_epoch_only skip node is less than the size of the dataset. This
  // assumption is not validated here.
  // Create the skip node pass which can identify which nodes need to be removed and which ones added.
  std::unique_ptr<SkipPushdownPass::SkipNodes> skip_nodes = std::make_unique<SkipPushdownPass::SkipNodes>();
  if (root_ir->IsSizeDefined()) {
    RETURN_IF_NOT_OK(skip_nodes->Run(root_ir, modified));
  }
  // Update modified flag if there were any nodes identified to be removed or inserted
  if (!skip_nodes->nodes_to_remove().empty() || !skip_nodes->insert_skip_above().empty()) {
    *modified = true;
  }
  // Add skip node(s) to the tree (if any). Iterate by const reference to avoid copying each shared_ptr pair.
  for (const auto &entry : skip_nodes->insert_skip_above()) {
    MS_LOG(INFO) << "Inserting a Skip(" << entry.second << ") node above this node: " << entry.first->Name();
    auto new_skip_node = std::make_shared<SkipNode>(entry.second);
    new_skip_node->SetFirstEpochOnly(true);
    RETURN_IF_NOT_OK(entry.first->InsertAbove(new_skip_node));
  }
  // Then, execute the removal of any nodes that were set up for removal
  for (const auto &node : skip_nodes->nodes_to_remove()) {
    RETURN_IF_NOT_OK(node->Drop());
  }
  MS_LOG(INFO) << "Pre pass: skip node pushdown pass is complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,156 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class BatchNode;
class DatasetNode;
class DatasetOp;
class MappableSourceNode;
class MapNode;
#ifndef ENABLE_ANDROID
class MindDataNode;
#endif
class NonMappableSourceNode;
class ProjectNode;
class RenameNode;
class SkipNode;
/// \class SkipPushdownPass skip_pushdown_pass.h
/// \brief This is a tree pass that will push down a skip node. It uses SkipNodes to first identify if we have a skip
///     node, and then based on the node types we observe in the tree, decide where to place the skip node (or use a
///     SkipFirstEpochSampler for MappableSource nodes).
class SkipPushdownPass : public IRTreePass {
  /// \class SkipNodes
  /// \brief This is a NodePass whose job is to handle different nodes accordingly.
  ///     It works in conjunction with the SkipPushdownPass.
  class SkipNodes : public IRNodePass {
   public:
    /// \brief Constructor
    SkipNodes();

    /// \brief Destructor
    ~SkipNodes() = default;

    /// \brief Perform skip node pushdown initiation check on a SkipNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<SkipNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown completion check on a SkipNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a BatchNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<BatchNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a ProjectNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a RenameNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<RenameNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a MappableSourceNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a MapNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown check on a NonMappableSourceNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;

#ifndef ENABLE_ANDROID
    /// \brief Perform skip node pushdown check on a MindDataNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override;
#endif

    /// \brief Perform skip node pushdown check on a DatasetNode
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;

    /// \brief Perform skip node pushdown completion check on a DatasetNode (no-op, always succeeds)
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override { return Status::OK(); }

    /// \brief Getter
    /// \return All the nodes where a skip node needs to be inserted above (and the skip count).
    const std::vector<std::pair<std::shared_ptr<DatasetNode>, int64_t>> &insert_skip_above() const {
      return insert_skip_above_;
    }

    /// \brief Getter
    /// \return All the nodes to be removed
    const std::vector<std::shared_ptr<DatasetNode>> &nodes_to_remove() const { return nodes_to_remove_; }

   private:
    std::vector<std::pair<std::shared_ptr<DatasetNode>, int64_t>> insert_skip_above_;  // (node, skip count) pairs
    std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_;                         // nodes queued for removal
    int64_t skip_count_;  // skip count accumulated so far and not yet consumed by a node
  };

 public:
  /// \brief Constructor
  SkipPushdownPass();

  /// \brief Destructor
  ~SkipPushdownPass() = default;

  /// \brief Runs a skip_pushdown pass to push down the skip node found in the tree (for Reset scenario).
  /// \param[in, out] root_ir The root of the IR tree to operate on.
  /// \param[in, out] modified Indicates whether the tree was modified.
  /// \return Status The status code returned
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_

View File

@ -36,6 +36,7 @@
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
namespace mindspore {
namespace dataset {
@ -57,8 +58,11 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
MS_LOG(INFO) << "Running pre pass loops.";
actions.emplace_back(std::make_unique<InputValidationPass>());
actions.emplace_back(std::make_unique<CacheValidationPass>());
if (usage_ == kDeReset) {
actions.emplace_back(std::make_unique<AddSkipPass>());
actions.emplace_back(std::make_unique<SkipPushdownPass>());
}
actions.emplace_back(std::make_unique<NodeRemovalPass>());
if (usage_ == kDeReset) actions.emplace_back(std::make_unique<AddSkipPass>());
actions.emplace_back(std::make_unique<EpochCtrlPass>());
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
#ifndef ENABLE_ANDROID

View File

@ -153,6 +153,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/random_sampler_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/sequential_sampler_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_random_sampler_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_sampler_ir.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.cc
@ -179,6 +180,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/opt/pre/node_removal_pass.cc
${MINDDATA_DIR}/engine/opt/pre/epoch_ctrl_pass.cc
${MINDDATA_DIR}/engine/opt/pre/deep_copy_pass.cc
${MINDDATA_DIR}/engine/opt/pre/skip_pushdown_pass.cc
${MINDDATA_DIR}/engine/opt/post/auto_worker_pass.cc
${MINDDATA_DIR}/engine/opt/pass.cc
${MINDDATA_DIR}/engine/perf/auto_tune.cc
@ -193,6 +195,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/datasetops/source/sampler/pk_sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/random_sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/sequential_sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/subset_random_sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/weighted_random_sampler.cc
${MINDDATA_DIR}/engine/runtime_context.cc

View File

@ -80,7 +80,6 @@ except ModuleNotFoundError:
if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork":
multiprocessing.set_start_method("fork", True)
OffloadToManualOffloadMode = {
None: cde.ManualOffloadMode.UNSPECIFIED,
False: cde.ManualOffloadMode.DISABLED,
@ -120,7 +119,7 @@ def _reset_training_dataset(step):
"""
dataset = _get_training_dataset()
if dataset is not None:
dataset.reset(step)
dataset._reset(step) # pylint: disable=W0212
else:
raise RuntimeError("Training dataset is not set.")
@ -3528,7 +3527,7 @@ class _ToDevice:
def send(self):
self._to_device.Send()
def reset(self, step):
def _reset(self, step):
self._to_device.Reset(step)
def stop_send(self):
@ -3638,10 +3637,10 @@ class TransferDataset(Dataset):
if self._to_device is not None:
self._to_device.continue_send()
def reset(self, step):
def _reset(self, step):
if self._to_device is not None:
logger.info("Reset the dataset pipeline to step " + str(step))
self._to_device.reset(step)
self._to_device._reset(step) # pylint: disable=W0212
def get_data_info(self):
"""

View File

@ -178,7 +178,7 @@ class Iterator:
self._getters()
return self._col_names
def reset(self, step):
def _reset(self, step):
"""
Reset the iterator to the given step number.

View File

@ -330,9 +330,9 @@ class DatasetHelper:
"""Continue to send data to device at the beginning of epoch."""
self.iter.continue_send()
def reset(self, step):
def _reset(self, step):
"""Reset the dataset to the provided step."""
self.iter.reset(step)
self.iter._reset(step) # pylint: disable=W0212
def get_data_info(self):
"""
@ -387,10 +387,11 @@ class _DatasetIter:
self.stop_send = dataset.__transfer_dataset__.stop_send
self.release = dataset.__transfer_dataset__.release
self.continue_send = dataset.__transfer_dataset__.continue_send
self.reset = dataset.__transfer_dataset__.reset
self.get_data_info = dataset.__transfer_dataset__.get_data_info
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
if hasattr(dataset.__transfer_dataset__, "_reset"):
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212
def __iter__(self):
self.index = 0

View File

@ -162,6 +162,7 @@ SET(DE_UT_SRCS
rgba_to_bgr_op_test.cc
rgba_to_rgb_op_test.cc
schema_test.cc
skip_pushdown_optimization_pass_test.cc
slice_op_test.cc
sliding_window_op_test.cc
solarize_op_test.cc

View File

@ -0,0 +1,676 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
#include "minddata/dataset/include/dataset/samplers.h"
#include "minddata/dataset/include/dataset/vision.h"
using namespace mindspore::dataset;
/// \brief Test fixture for the skip pushdown optimization pass.
/// \note Every test builds an "original" pipeline, compiled with TreeAdapter::UsageFlag::kDeReset and a step count,
///     and a hand-written "target" pipeline compiled with a default TreeAdapter. The helpers below check that both
///     IR trees have the same shape (only when step != 0) and that both pipelines produce identical rows.
class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting {
 protected:
  MindDataSkipPushdownTestOptimizationPass() {}

  /// \brief Compile and compare two datasets
  /// \param[in] root_original Original dataset to be added the skip step
  /// \param[in] root_target Target dataset for compare
  /// \param[in] step Skip step
  /// \return Status of the function
  Status prepare_trees(std::shared_ptr<Dataset> root_original, std::shared_ptr<Dataset> root_target, int64_t step = 0) {
    auto ir_tree = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
    // Compile adds a new RootNode to the top of the tree
    RETURN_IF_NOT_OK(ir_tree->Compile(root_original->IRNode(), 1, step));

    auto ir_tree_target = std::make_shared<TreeAdapter>();
    // Compile adds a new RootNode to the top of the tree
    RETURN_IF_NOT_OK(ir_tree_target->Compile(root_target->IRNode(), 1, 0));  // Step is 0 for target node tree

    // The tree shapes are only compared when a reset step was requested; with step == 0 only the rows are checked.
    if (step != 0) {
      RETURN_IF_NOT_OK(compare_pass(ir_tree_target->RootIRNode(), ir_tree->RootIRNode()));
    }
    RETURN_IF_NOT_OK(compare_pass_row(ir_tree_target, ir_tree));
    return Status::OK();
  }

  /// \brief Compare two dataset node trees
  /// \note Only the descendants are compared (child count and child names, recursively); the names of the two nodes
  ///     passed in are not compared with each other.
  /// \param[in] expect Expected node tree for compare
  /// \param[in] root Root node tree for compare
  /// \return Status of the function
  Status compare_pass(std::shared_ptr<DatasetNode> expect, std::shared_ptr<DatasetNode> root) {
    // Fetch each child list once instead of calling Children() on every access.
    const auto &expect_children = expect->Children();
    const auto &root_children = root->Children();
    if (expect_children.size() != root_children.size()) {
      return Status(StatusCode::kMDUnexpectedError, "Skip Optimization is not working as expected, expect to have " +
                                                      std::to_string(expect_children.size()) +
                                                      " operation, but got " + std::to_string(root_children.size()));
    }
    // Leaf nodes (no children on either side) fall straight through the loop.
    for (size_t i = 0; i < expect_children.size(); ++i) {
      const std::string expect_name = expect_children[i]->Name();
      const std::string root_name = root_children[i]->Name();
      CHECK_FAIL_RETURN_UNEXPECTED(expect_name == root_name,
                                   "Expect child is " + expect_name + ", but got " + root_name);
      RETURN_IF_NOT_OK(compare_pass(expect_children[i], root_children[i]));
    }
    return Status::OK();
  }

  /// \brief Compare each row of two dataset node trees
  /// \note Rows are compared column by column through their JSON text form. Iteration stops as soon as either
  ///     pipeline returns an empty row; the final size check then fails if only one of them was exhausted.
  /// \param[in] expect Expected tree for compare
  /// \param[in] root Root tree for compare
  /// \return Status of the function
  Status compare_pass_row(std::shared_ptr<TreeAdapter> expect, std::shared_ptr<TreeAdapter> root) {
    TensorRow row_expect;
    TensorRow row_root;
    RETURN_IF_NOT_OK(expect->GetNext(&row_expect));
    RETURN_IF_NOT_OK(root->GetNext(&row_root));
    while (row_expect.size() != 0 && row_root.size() != 0) {
      const auto &e = row_expect.getRow();
      const auto &r = row_root.getRow();
      // Both rows must have the same number of columns, otherwise r[i] below would be read out of bounds.
      CHECK_FAIL_RETURN_UNEXPECTED(e.size() == r.size(), "Expect row to have " + std::to_string(e.size()) +
                                                           " column(s), but got " + std::to_string(r.size()));
      for (size_t i = 0; i < e.size(); ++i) {
        nlohmann::json expect_json;
        RETURN_IF_NOT_OK(e[i]->to_json(&expect_json));
        nlohmann::json root_json;
        RETURN_IF_NOT_OK(r[i]->to_json(&root_json));
        // Compare the text form so that a failure prints both tensors.
        EXPECT_EQ(expect_json.dump(), root_json.dump());
      }
      RETURN_IF_NOT_OK(expect->GetNext(&row_expect));
      RETURN_IF_NOT_OK(root->GetNext(&row_root));
    }
    // Both pipelines must run out of rows at the same time.
    EXPECT_EQ(row_expect.size(), row_root.size());
    return Status::OK();
  }
};
// Forward declarations only: no definitions for these helpers are visible in this part of the file, so they are
// presumably provided by another translation unit linked into the test binary (TODO confirm).
TensorRow VecToRow(const MSTensorVec &v);
MSTensorVec RowToVec(const TensorRow &v);
// Passed as the predicate argument to Filter() in the SkipPushdownUnsupported_Filter test.
MSTensorVec Predicate1(MSTensorVec in);
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset an ImageFolder (mappable source) that uses a SequentialSampler to step 2
/// Expectation: Same tree and rows as an ImageFolder built with SequentialSampler(2), i.e. no Skip node remains
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownMappableSourceNode) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownMappableSourceNode.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Pipeline under test: plain source, reset to step 2.
  std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  // Reference pipeline: the skip is expressed through the sampler argument instead.
  std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(2));

  EXPECT_OK(prepare_trees(original, expected, 2));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip 5 batches above Batch(5), once as an explicit Skip node and once as a reset step
/// Expectation: Both match an ImageFolder built with SequentialSampler(25) followed by Batch(5)
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  {
    // Explicit Skip(5) after Batch(5), no reset step (rows only are compared).
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Batch(5)->Skip(5);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(25))->Batch(5);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // No explicit Skip; the pipeline is reset to step 5 instead.
    std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Batch(5);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(25))->Batch(5);
    EXPECT_OK(prepare_trees(original, expected, 5));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip 5 rows above a Rename, once as an explicit Skip node and once as a reset step
/// Expectation: Both match an ImageFolder built with SequentialSampler(5) followed by the same Rename
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRename) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRename.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  {
    // Explicit Skip(5) after Rename, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Rename({"label"}, {"fake_label"})
                                          ->Skip(5);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(5))->Rename({"label"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // No explicit Skip; the pipeline is reset to step 5 instead.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Rename({"label"}, {"fake_label"});
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(5))->Rename({"label"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 5));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip rows above a Project via an explicit Skip node, a reset step, and two stacked Skip nodes
/// Expectation: Each case matches an ImageFolder whose SequentialSampler argument carries the total skip count
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownProject) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownProject.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  {
    // Explicit Skip(10) after Project, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"label", "image"})->Skip(10);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(10))->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // No explicit Skip; the pipeline is reset to step 10 instead.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"label", "image"});
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(10))->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 10));
  }

  {
    // Skip(1) below and Skip(10) above the Project add up to 11 skipped rows.
    std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Skip(1)
                                          ->Project({"label", "image"})
                                          ->Skip(10);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(11))->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip 10 rows above a Concat, once as an explicit Skip node and once as a reset step
/// Expectation: The Skip node stays on top of the Concat in both cases
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownConcat) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownConcat.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Second input of the Concat; shared by every pipeline built in this test.
  std::vector<std::shared_ptr<Dataset>> concat_inputs = {
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())};

  {
    // Explicit Skip(10) on both sides, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Concat(concat_inputs)->Skip(10);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Concat(concat_inputs)->Skip(10);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Reset to step 10: the reference keeps a Skip(10) above the Concat.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Concat(concat_inputs);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Concat(concat_inputs)->Skip(10);
    EXPECT_OK(prepare_trees(original, expected, 10));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset a pipeline that ends in a Zip to step 10
/// Expectation: The Skip node stays on top of the Zip
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownZip) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownZip.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Second input of the Zip: the "label" column only.
  std::vector<std::shared_ptr<Dataset>> zip_inputs = {
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"label"})};

  std::shared_ptr<Dataset> original =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(zip_inputs);
  std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                        ->Project({"image"})
                                        ->Zip(zip_inputs)
                                        ->Skip(10);

  EXPECT_OK(prepare_trees(original, expected, 10));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset a pipeline that ends in Repeat(5) to step 11
/// Expectation: The Skip node stays on top of the Repeat
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Repeat(5);
  std::shared_ptr<Dataset> expected =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Repeat(5)->Skip(11);

  EXPECT_OK(prepare_trees(original, expected, 11));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset a pipeline that ends in Take(20) to step 10
/// Expectation: The Skip node stays on top of the Take
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Take(20);
  std::shared_ptr<Dataset> expected =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Take(20)->Skip(10);

  EXPECT_OK(prepare_trees(original, expected, 10));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset a pipeline that already ends in Skip(2)->Skip(3) to step 5
/// Expectation: The new Skip(5) stays on top of the existing Skip nodes
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(3);
  std::shared_ptr<Dataset> expected =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(3)->Skip(5);

  EXPECT_OK(prepare_trees(original, expected, 5));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip 1 row above a Rename on a CSV (non-mappable) source, as a reset step and as an explicit Skip
/// Expectation: The Skip node moves below the Rename but remains directly above the CSV source
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode.";
  const std::string csv_file = datasets_root_path_ + "/testCSV/append.csv";
  const std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};

  // Every pipeline in this test starts from the same unshuffled CSV source.
  auto csv_source = [&csv_file, &column_names]() {
    return CSV({csv_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
  };

  {
    // Reset to step 1.
    std::shared_ptr<Dataset> original = csv_source()->Rename({"col1"}, {"fake_label"});
    std::shared_ptr<Dataset> expected = csv_source()->Skip(1)->Rename({"col1"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Explicit Skip(1) after Rename, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = csv_source()->Rename({"col1"}, {"fake_label"})->Skip(1);
    std::shared_ptr<Dataset> expected = csv_source()->Skip(1)->Rename({"col1"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset Batch(5)->Skip(2)->Rename to step 2
/// Expectation: The new Skip(2) moves below the Rename and stops above the existing Skip(2)
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations1) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations1.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                        ->Batch(5)
                                        ->Skip(2)
                                        ->Rename({"label"}, {"fake_label"});
  std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                        ->Batch(5)
                                        ->Skip(2)
                                        ->Skip(2)
                                        ->Rename({"label"}, {"fake_label"});

  EXPECT_OK(prepare_trees(original, expected, 2));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset Skip(2)->Batch(5) to step 2
/// Expectation: The new skip moves below Batch(5) as Skip(10) and stops above the existing Skip(2)
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations2) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations2.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Skip(2)->Batch(5);
  std::shared_ptr<Dataset> expected =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(10)->Batch(5);

  EXPECT_OK(prepare_trees(original, expected, 2));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset Take(20)->Project to step 2
/// Expectation: The Skip node moves below the Project but stays above the Take
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations3) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations3.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  std::shared_ptr<Dataset> original =
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Take(20)->Project({"label", "image"});
  std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                        ->Take(20)
                                        ->Skip(2)
                                        ->Project({"label", "image"});

  EXPECT_OK(prepare_trees(original, expected, 2));
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Pipelines with no Skip, with Skip(0) nodes, and with existing Skip nodes plus a reset step of 1
/// Expectation: Skip(0) behaves like no Skip at all; a reset step adds a Skip above any existing Skip nodes
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip0) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip0.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Every pipeline in this test starts from the same sequentially sampled ImageFolder.
  auto plain_source = [&data_dir]() {
    return ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  };

  {
    // No Skip anywhere and no reset step: identical pipelines.
    std::shared_ptr<Dataset> original = plain_source()->Project({"label", "image"})->Take(5);
    std::shared_ptr<Dataset> expected = plain_source()->Project({"label", "image"})->Take(5);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // A trailing Skip(0) yields the same rows as no Skip.
    std::shared_ptr<Dataset> original = plain_source()->Project({"label", "image"})->Skip(0);
    std::shared_ptr<Dataset> expected = plain_source()->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Skip(0) on both sides of the Project yields the same rows as no Skip.
    std::shared_ptr<Dataset> original = plain_source()->Skip(0)->Project({"label", "image"})->Skip(0);
    std::shared_ptr<Dataset> expected = plain_source()->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Skip(0) below the Project combined with a reset to step 1.
    std::shared_ptr<Dataset> original = plain_source()->Skip(0)->Project({"label", "image"});
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Two existing Skip nodes below the Project combined with a reset to step 1.
    std::shared_ptr<Dataset> original = plain_source()->Skip(2)->Skip(1)->Project({"label", "image"});
    std::shared_ptr<Dataset> expected = plain_source()->Skip(2)->Skip(1)->Skip(1)->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Repeat(3) pipelines with an explicit Skip and with a reset step of 50
/// Expectation: The Skip node always stays on top of the Repeat
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat2) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat2.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Every pipeline in this test starts from the same sequentially sampled ImageFolder.
  auto plain_source = [&data_dir]() {
    return ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  };

  {
    // Explicit Skip(1) after Repeat(3) on both sides, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = plain_source()->Repeat(3)->Skip(1);
    std::shared_ptr<Dataset> expected = plain_source()->Repeat(3)->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Reset to step 50.
    std::shared_ptr<Dataset> original = plain_source()->Repeat(3);
    std::shared_ptr<Dataset> expected = plain_source()->Repeat(3)->Skip(50);
    EXPECT_OK(prepare_trees(original, expected, 50));
  }

  {
    // Existing Skip(1) below the Repeat, reset to step 50.
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Repeat(3);
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Repeat(3)->Skip(50);
    EXPECT_OK(prepare_trees(original, expected, 50));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Take pipelines with an existing Skip below the Take and with a reset step of 1
/// Expectation: The Skip node added for the reset step always stays on top of the Take
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake2) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake2.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Every pipeline in this test starts from the same sequentially sampled ImageFolder.
  auto plain_source = [&data_dir]() {
    return ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  };

  {
    // Identical pipelines, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Take(3);
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Take(3);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Reset to step 1.
    std::shared_ptr<Dataset> original = plain_source()->Take(3);
    std::shared_ptr<Dataset> expected = plain_source()->Take(3)->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Existing Skip(1) below the Take, reset to step 1.
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Take(5);
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Take(5)->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip rows above Concat and Zip, as an explicit Skip node and as a reset step
/// Expectation: The Skip node stays on top of the Concat/Zip in every case
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Second input used by both the Concat and the Zip pipelines: the "label" column only.
  std::vector<std::shared_ptr<Dataset>> other_inputs = {
    ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"label"})};

  {
    // Concat, reset to step 2.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"label"})->Concat(other_inputs);
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Project({"label"})
                                          ->Concat(other_inputs)
                                          ->Skip(2);
    EXPECT_OK(prepare_trees(original, expected, 2));
  }

  {
    // Zip with an explicit Skip(1) on both sides, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Project({"image"})
                                          ->Zip(other_inputs)
                                          ->Skip(1);
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Project({"image"})
                                          ->Zip(other_inputs)
                                          ->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Zip, reset to step 1.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(other_inputs);
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, std::make_shared<SequentialSampler>())
                                          ->Project({"image"})
                                          ->Zip(other_inputs)
                                          ->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip rows around a Filter, as an explicit Skip node and as a reset step
/// Expectation: The Skip node added for the reset step stays on top of the Filter
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported_Filter) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported_Filter.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Every pipeline in this test starts from the same sequentially sampled ImageFolder.
  auto plain_source = [&data_dir]() {
    return ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  };

  {
    // Skip below the Filter on both sides, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Filter(Predicate1, {"label"});
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Filter(Predicate1, {"label"});
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Skip above the Filter on both sides, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = plain_source()->Filter(Predicate1, {"label"})->Skip(1);
    std::shared_ptr<Dataset> expected = plain_source()->Filter(Predicate1, {"label"})->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Existing Skip(1) below the Filter, reset to step 1.
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Filter(Predicate1, {"label"});
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Filter(Predicate1, {"label"})->Skip(1);
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Skip 1 row on ImageFolder pipelines that use a SubsetRandomSampler, bare and under Rename/Project/Map
/// Expectation: Each case matches an ImageFolder whose sampler is SequentialSampler(1) with the SubsetRandomSampler
///     attached as its child, with no Skip node left in the tree
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSubsetSampler) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSubsetSampler.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";
  const std::vector<int64_t> indices = {0, 1, 2, 3, 4, 5};

  {
    // Explicit Skip(1) directly on the source, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SubsetRandomSampler>(indices, 3))->Skip(1);
    auto chained = std::make_shared<SequentialSampler>(1);
    chained->AddChild(std::make_shared<SubsetRandomSampler>(indices, 3));
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, chained);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Rename on top, reset to step 1.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SubsetRandomSampler>(indices, 4))->Rename({"label"}, {"fake_label"});
    auto chained = std::make_shared<SequentialSampler>(1);
    chained->AddChild(std::make_shared<SubsetRandomSampler>(indices, 4));
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, chained)->Rename({"label"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Project on top, reset to step 1.
    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, false, std::make_shared<SubsetRandomSampler>(indices, 10))->Project({"label", "image"});
    auto chained = std::make_shared<SequentialSampler>(1);
    chained->AddChild(std::make_shared<SubsetRandomSampler>(indices, 10));
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, false, chained)->Project({"label", "image"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Map (AutoContrast followed by CenterCrop) on top of a decoding source, reset to step 1.
    const std::vector<int32_t> crop_size = {80, 80};
    const std::vector<uint32_t> ignore = {20, 20, 20, 20};
    std::shared_ptr<TensorTransform> auto_contrast = std::make_shared<vision::AutoContrast>(0.5, ignore);
    std::shared_ptr<TensorTransform> center_crop = std::make_shared<vision::CenterCrop>(crop_size);
    std::vector<std::shared_ptr<TensorTransform>> transforms = {auto_contrast, center_crop};

    std::shared_ptr<Dataset> original =
      ImageFolder(data_dir, true, std::make_shared<SubsetRandomSampler>(indices, 3))->Map(transforms);
    auto chained = std::make_shared<SequentialSampler>(1);
    chained->AddChild(std::make_shared<SubsetRandomSampler>(indices, 3));
    std::shared_ptr<Dataset> expected = ImageFolder(data_dir, true, chained)->Map(transforms);
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Batch(5) pipelines with and without an existing Skip below the Batch, reset to step 3
/// Expectation: The reset skip moves below the Batch as a row count of 15 (3 batches of 5)
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch2) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch2.";
  const std::string data_dir = datasets_root_path_ + "/testPK/data/";

  // Source with a default-constructed SequentialSampler, shared by most pipelines below.
  auto plain_source = [&data_dir]() {
    return ImageFolder(data_dir, false, std::make_shared<SequentialSampler>());
  };

  {
    // Identical pipelines, no reset step (rows only are compared).
    std::shared_ptr<Dataset> original = plain_source()->Skip(1)->Batch(5);
    std::shared_ptr<Dataset> expected = plain_source()->Skip(1)->Batch(5);
    EXPECT_OK(prepare_trees(original, expected, 0));
  }

  {
    // Existing Skip(3) below the Batch, reset to step 3: a Skip(15) lands above the existing Skip.
    std::shared_ptr<Dataset> original = plain_source()->Skip(3)->Batch(5);
    std::shared_ptr<Dataset> expected = plain_source()->Skip(3)->Skip(15)->Batch(5);
    EXPECT_OK(prepare_trees(original, expected, 3));
  }

  {
    // No existing Skip, reset to step 3: the reference uses SequentialSampler(15) and has no Skip node.
    std::shared_ptr<Dataset> original = plain_source()->Batch(5);
    std::shared_ptr<Dataset> expected =
      ImageFolder(data_dir, false, std::make_shared<SequentialSampler>(15))->Batch(5);
    EXPECT_OK(prepare_trees(original, expected, 3));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Reset CSV (non-mappable) pipelines containing Rename, Batch, Project, Repeat and Skip to step 1
/// Expectation: The Skip node moves below Rename/Project/Batch (scaled by the batch size) but is never removed, and
///     it stops above an existing Skip
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode2) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode2.";
  const std::string csv_file = datasets_root_path_ + "/testCSV/append.csv";
  const std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};

  // Every pipeline in this test starts from the same unshuffled CSV source.
  auto csv_source = [&csv_file, &column_names]() {
    return CSV({csv_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
  };

  {
    // Rename only.
    std::shared_ptr<Dataset> original = csv_source()->Rename({"col1"}, {"fake_label"});
    std::shared_ptr<Dataset> expected = csv_source()->Skip(1)->Rename({"col1"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Existing Skip(1) below the Rename.
    std::shared_ptr<Dataset> original = csv_source()->Skip(1)->Rename({"col1"}, {"fake_label"});
    std::shared_ptr<Dataset> expected = csv_source()->Skip(1)->Skip(1)->Rename({"col1"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }

  {
    // Longer chain: the reference has a Skip(2) directly below Batch(2), above the existing Skip(1).
    std::shared_ptr<Dataset> original = csv_source()
                                          ->Repeat(5)
                                          ->Skip(1)
                                          ->Batch(2)
                                          ->Project({"col1", "col2"})
                                          ->Rename({"col1"}, {"fake_label"});
    std::shared_ptr<Dataset> expected = csv_source()
                                          ->Repeat(5)
                                          ->Skip(1)
                                          ->Skip(2)
                                          ->Batch(2)
                                          ->Project({"col1", "col2"})
                                          ->Rename({"col1"}, {"fake_label"});
    EXPECT_OK(prepare_trees(original, expected, 1));
  }
}
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Test MindData Skip Pushdown Optimization Pass with combined Operations
///     (Repeat/Skip/Take/Rename/Batch over ImageFolder with SubsetRandomSampler, and over a CSV source)
/// Expectation: Skip node is removed/reduced for Rename/Batch after optimization pass
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations4) {
  MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations4.";
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> root;
  std::shared_ptr<Dataset> root_target;
  std::vector<int64_t> indices = {0, 1, 2, 3, 4, 5};

  // NOTE(review): prepare_trees is defined outside this chunk; its last argument is presumably the skip step
  // applied while compiling `root` - confirm against the fixture.

  // Scenario 1: alternating Skip/Take above Repeat. The target wraps the SubsetRandomSampler as a child of
  // SequentialSampler(1) and carries an extra Skip(1) directly above the last Take.
  root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 6))
           ->Repeat(4)
           ->Skip(4)
           ->Take(10)
           ->Skip(2)
           ->Take(2)
           ->Rename({"label"}, {"fake_label"});
  auto sampler = std::make_shared<SequentialSampler>(1);
  sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 6));
  root_target = ImageFolder(folder_path, false, sampler)
                  ->Repeat(4)
                  ->Skip(4)
                  ->Take(10)
                  ->Skip(2)
                  ->Take(2)
                  ->Skip(1)
                  ->Rename({"label"}, {"fake_label"});
  EXPECT_OK(prepare_trees(root, root_target, 1));

  // Scenario 2: Rename and Batch(4) above Take. The target keeps the plain SubsetRandomSampler on the source
  // and carries a Skip(8) directly above Take(10), so no wrapping sampler is built for this case.
  root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 5))
           ->Repeat(8)
           ->Skip(4)
           ->Skip(2)
           ->Take(10)
           ->Rename({"label"}, {"fake_label"})
           ->Batch(4);
  root_target = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 5))
                  ->Repeat(8)
                  ->Skip(4)
                  ->Skip(2)
                  ->Take(10)
                  ->Skip(8)
                  ->Rename({"label"}, {"fake_label"})
                  ->Batch(4);
  EXPECT_OK(prepare_trees(root, root_target, 2));

  // Scenario 3: CSV source with a user-written Skip(1) above Batch(3); the target expresses it as Skip(3)
  // below Batch(3). The step argument is 0 here.
  std::string csv_path = datasets_root_path_ + "/testCSV/append.csv";
  std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
  root = CSV({csv_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
           ->Skip(1)
           ->Repeat(4)
           ->Batch(3)
           ->Skip(1)
           ->Rename({"col1"}, {"fake_label"});
  root_target = CSV({csv_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
                  ->Skip(1)
                  ->Repeat(4)
                  ->Skip(3)
                  ->Batch(3)
                  ->Rename({"col1"}, {"fake_label"});
  EXPECT_OK(prepare_trees(root, root_target, 0));
}